[Mlir-commits] [mlir] [mlir][linalg] Add scalable vectorisation for depthwise convolutions (PR #81625)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Feb 14 14:13:13 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/81625

>From e4755c1d2c597fb212ff8e009a56cfa8c7fc38b8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 8 Feb 2024 12:45:40 +0000
Subject: [PATCH 1/2] [mlir][linalg] Add scalable vectorisation for depthwise
 convolutions

This patch adds support for scalable vectorisation of depthwise 1D HWC
convolutions, `linalg.depthwise_conv_1d_nwc_wc`. This is implemented by
adding support for masking.

Two major assumptions are made:
  * only the channel dimension can be scalable/dynamic (i.e. the
    trailing dim),
  * when specifying vector sizes to use in the vectoriser, only the size
    corresponding to the channel dim is effectively used (other dims are
    inferred from the context).

In terms of scalable vectorisation, this should be sufficient to cover
all practical cases (i.e. making an arbitrary dim scalable wouldn't make
much sense). As for more generic cases with dynamic shapes (e.g. w or n
dims being dynamic), more work would be needed. In particular, one would
have to consider the filter and input/output tensors separately. However,
it's not clear whether that would be of any use in practice.
---
 .../Linalg/Transforms/Vectorization.cpp       | 131 +++++++++++---
 .../Linalg/vectorize-conv-scalable.mlir       | 161 ++++++++++++++++++
 2 files changed, 273 insertions(+), 19 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bd6929fea6142..aed96be5b6da00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -54,6 +54,7 @@ using namespace mlir::linalg;
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     ArrayRef<int64_t> inputVecSizes = {},
                      bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1609,6 +1610,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  // Support dynamic shapes in 1D depthwise convolution, but only in the
+  // _channel_ dimension. That's exclusively to support scalable vectorisation.
+  if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+    auto lhsShaped = op.getDpsInputOperand(0)->get();
+    ArrayRef<int64_t> lhsShape =
+        dyn_cast<ShapedType>(lhsShaped.getType()).getShape();
+    auto shapeWithoutCh = lhsShape.drop_back(1);
+    if (ShapedType::isDynamicShape(shapeWithoutCh))
+      return failure();
+
+    return success();
+  }
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -1789,7 +1803,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Only element-wise ops supported in the presence of scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
-  return success(linalgOp && isElementwise(linalgOp));
+  return success(linalgOp && (isElementwise(linalgOp) ||
+                              isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -1871,7 +1886,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // features. Will require stride/dilation attributes inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, flatten1DDepthwiseConv);
+                  rewriter, linalgOp, inputVectorSizes, flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -2697,6 +2712,7 @@ struct Conv1DGenerator
         return;
       break;
     }
+    hasTensorSemantics = linalgOp.hasPureTensorSemantics();
     // The op is now known to be valid.
     valid = true;
   }
@@ -3000,13 +3016,21 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv(bool flatten) {
+  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
+                                       bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
+    bool scalableChDim = false;
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
     bindShapeDims(rhsShapedType, kwSize, cSize);
+    // Dynamic channel size implies scalable vectorisation
+    if (ShapedType::isDynamic(cSize)) {
+      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
+      cSize = channelDimVecSize;
+      scalableChDim = true;
+    }
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3027,20 +3051,74 @@ struct Conv1DGenerator
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
-        lhsEltType);
-    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
-    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+        lhsEltType, {false, false, scalableChDim});
+    VectorType rhsType =
+        VectorType::get({kwSize, cSize}, rhsEltType,
+                        /*scalableDims=*/{false, scalableChDim});
+    VectorType resType =
+        VectorType::get({nSize, wSize, cSize}, resEltType,
+                        /*scalableDims=*/{false, false, scalableChDim});
+
+    // Masks the input xfer Op along the channel dim, iff the corresponding
+    // scalable flag is set.
+    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
+                               ArrayRef<bool> scalableDims,
+                               Operation *opToMask) {
+      bool scalableChDim = scalableDims.back();
+      if (!scalableChDim)
+        return opToMask;
+
+      auto maskType =
+          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
+
+      SmallVector<OpFoldResult> mixedSourceDims =
+          hasTensorSemantics
+              ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+                    .Case<vector::TransferReadOp>([&](auto readOp) {
+                      return tensor::getMixedSizes(rewriter, loc,
+                                                   readOp.getSource());
+                    })
+                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
+                      return tensor::getMixedSizes(rewriter, loc,
+                                                   writeOp.getOperand(1));
+                    })
+              : TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+                    .Case<vector::TransferReadOp>([&](auto readOp) {
+                      return memref::getMixedSizes(rewriter, loc,
+                                                   readOp.getSource());
+                    })
+                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
+                      return memref::getMixedSizes(rewriter, loc,
+                                                   writeOp.getOperand(1));
+                    });
+
+      Value maskOp =
+          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+
+      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
+    };
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedLHS = maybeMaskXferOp(
+        lhsType.getShape(),
+        /*scalableDims=*/{false, false, scalableChDim}, lhs.getDefiningOp());
+
     // Read rhs slice of size {kw, c} @ [0, 0].
     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                         ValueRange{zero, zero});
+    auto maybeMaskedRHS = maybeMaskXferOp(
+        rhsType.getShape(),
+        /*scalableDims=*/{false, scalableChDim}, rhs.getDefiningOp());
+
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedRES = maybeMaskXferOp(
+        resType.getShape(),
+        /*scalableDims=*/{false, false, scalableChDim}, res.getDefiningOp());
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
@@ -3055,7 +3133,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
+            loc, maybeMaskedLHS->getResult(0),
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             inOutSliceSizes, inOutStrides));
       }
@@ -3063,12 +3141,13 @@ struct Conv1DGenerator
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+          loc, maybeMaskedRHS->getResult(0),
+          /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
+          loc, maybeMaskedRES->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
           inOutStrides));
     }
@@ -3107,6 +3186,7 @@ struct Conv1DGenerator
     // Its possible we failed to create the Fma.
     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
       // Manually revert (in reverse order) to avoid leaving a bad IR state.
+      // TODO: Replace with maybeMasked
       for (auto &collection :
            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
         for (Value v : collection)
@@ -3117,8 +3197,8 @@ struct Conv1DGenerator
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
+      maybeMaskedRES = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], maybeMaskedRES->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
     }
@@ -3127,10 +3207,12 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
-    return rewriter
-        .create<vector::TransferWriteOp>(loc, res, resShaped,
-                                         ValueRange{zero, zero, zero})
-        .getOperation();
+    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
+        loc, maybeMaskedRES->getResult(0), resShaped,
+        ValueRange{zero, zero, zero});
+    return maybeMaskXferOp(resType.getShape(),
+                           /*scalableDims=*/{false, false, scalableChDim},
+                           resOut);
   }
 
   /// Lower:
@@ -3171,8 +3253,9 @@ struct Conv1DGenerator
     if (!lhs || !rhs)
       return nullptr;
 
-    if (isa<FloatType>(resTy.getElementType()))
+    if (isa<FloatType>(resTy.getElementType())) {
       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+    }
 
     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
     return rewriter.create<arith::AddIOp>(loc, mul, res);
@@ -3268,7 +3351,8 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
+  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
+                                             bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3279,7 +3363,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv(flatten);
+      return depthwiseConv(vecChDimSize, flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3291,6 +3375,7 @@ struct Conv1DGenerator
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
+  bool hasTensorSemantics = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
@@ -3346,6 +3431,7 @@ struct Conv1DGenerator
 // TODO: extend the generic vectorization to support windows and drop this.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     ArrayRef<int64_t> inputVecSizes,
                      bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
@@ -3371,7 +3457,14 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv(flatten1DDepthwiseConv);
+
+  uint64_t vecChDimSize = ShapedType::kDynamic;
+  if (!inputVecSizes.empty()) {
+    // Only use the input vector size corresponding to the channel dim. Other
+    // vector dims will be inferred from the Ops.
+    vecChDimSize = inputVecSizes[2];
+  }
+  return e.generateDilatedConv(vecChDimSize, flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
new file mode 100644
index 00000000000000..d4b3574451c2bf
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
+                                                   %filter: tensor<1x?xi8>,
+                                                   %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK-DAG:       arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[FILTER]], %[[VAL_5]] : tensor<1x?xi8>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C0_I8:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_SIZE_INPUT:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_SIZE_INPUT]] : vector<1x8x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[C0_I8_1:.*]] = arith.constant 0 : i8
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[CH_DIM_SIZE_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK:           %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1_1]], %[[CH_DIM_SIZE_FLT]] : vector<1x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[C0_I8_1]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_23:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_SIZE_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_23]] : tensor<1x8x?xi8>
+// CHECK:           %[[VAL_25:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[CH_DIM_SIZE_OUT]] : vector<1x8x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[VAL_22]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Convolution
+// CHECK:           %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
+// CHECK:           %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[FLT_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[VEC_IN_0]], %[[FLT_B]] : vector<1x8x[4]xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[VEC_OUT_0]] : vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_36:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_37:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_36]] : tensor<1x8x?xi8>
+// CHECK:           %[[VAL_38:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_39:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_38]], %[[VAL_39]], %[[VAL_37]] : vector<1x8x[4]xi1>
+
+/// Write the output tensor
+// CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x?xf32>,
+                                                                %filter: memref<2x?xf32>,
+                                                                %output: memref<3x2x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
+    outs(%output : memref<3x2x?xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
+
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[FILTER]], %[[VAL_5]] : memref<2x?xf32>
+// CHECK:           %[[VAL_7:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+
+/// Create a mask for the input tensor
+// CHECK:           %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_11:.*]] = memref.dim %[[INPUT]], %[[VAL_10]] : memref<3x5x?xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] : vector<3x4x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_17:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_18:.*]] = memref.dim %[[FILTER]], %[[VAL_17]] : memref<2x?xf32>
+// CHECK:           %[[VAL_19:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_20:.*]] = vector.create_mask %[[VAL_19]], %[[VAL_18]] : vector<2x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[VAL_20]] { vector.transfer_read %[[FILTER]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_16]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_23:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_24:.*]] = memref.dim %[[OUTPUT]], %[[VAL_23]] : memref<3x2x?xf32>
+// CHECK:           %[[VAL_25:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant 2 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[VAL_24]] : vector<3x2x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_22]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
+
+/// Convolution
+// CHECK:           %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[VEC_FLT_1:.*]] = vector.extract %[[VEC_FLT]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_FLT_0_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FMA_1:.*]] = vector.fma %[[VEC_IN_0]], %[[VEC_FLT_0_B]], %[[VEC_OUT_0]] : vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_FLT_1_B:.*]] = vector.broadcast %[[VEC_FLT_1]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FMA_2:.*]] = vector.fma %[[VEC_IN_1]], %[[VEC_FLT_1_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_39:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_40:.*]] = memref.dim %[[OUTPUT]], %[[VAL_39]] : memref<3x2x?xf32>
+// CHECK:           %[[VAL_41:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_42:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_43:.*]] = vector.create_mask %[[VAL_41]], %[[VAL_42]], %[[VAL_40]] : vector<3x2x[4]xi1>
+/// Write the output tensor
+// CHECK:           vector.mask %[[VAL_43]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
+    transform.yield
+  }
+}

>From 5f78c47956412c6fbb9b20fedbb5b6d376b5786b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 14 Feb 2024 22:12:34 +0000
Subject: [PATCH 2/2] fixup! [clang][Driver] Small correction to
 print-runtime-dir

Addressing Cullen's comments
---
 .../Linalg/Transforms/Vectorization.cpp       | 46 +++++++++----------
 .../Linalg/vectorization-unsupported.mlir     | 19 ++++++++
 2 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index aed96be5b6da00..7f452769f01359 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1610,12 +1610,13 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
-  // Support dynamic shapes in 1D depthwise convolution, but only in the
-  // _channel_ dimension. That's exclusively to support scalable vectorisation.
   if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+    // Support dynamic shapes in 1D depthwise convolution, but only in the
+    // _channel_ dimension. That's exclusively to support scalable
+    // vectorisation.
     auto lhsShaped = op.getDpsInputOperand(0)->get();
     ArrayRef<int64_t> lhsShape =
-        dyn_cast<ShapedType>(lhsShaped.getType()).getShape();
+        cast<ShapedType>(lhsShaped.getType()).getShape();
     auto shapeWithoutCh = lhsShape.drop_back(1);
     if (ShapedType::isDynamicShape(shapeWithoutCh))
       return failure();
@@ -1801,7 +1802,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (!isScalable)
     return success();
 
-  // Only element-wise ops supported in the presence of scalable dims.
+  // Only element-wise and 1d depthwise conv ops supported in the presence of
+  // scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
   return success(linalgOp && (isElementwise(linalgOp) ||
                               isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
@@ -3051,7 +3053,7 @@ struct Conv1DGenerator
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
-        lhsEltType, {false, false, scalableChDim});
+        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
     VectorType rhsType =
         VectorType::get({kwSize, cSize}, rhsEltType,
                         /*scalableDims=*/{false, scalableChDim});
@@ -3102,23 +3104,20 @@ struct Conv1DGenerator
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
-    auto maybeMaskedLHS = maybeMaskXferOp(
-        lhsType.getShape(),
-        /*scalableDims=*/{false, false, scalableChDim}, lhs.getDefiningOp());
+    auto maybeMaskedLhs = maybeMaskXferOp(
+        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
 
     // Read rhs slice of size {kw, c} @ [0, 0].
     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                         ValueRange{zero, zero});
-    auto maybeMaskedRHS = maybeMaskXferOp(
-        rhsType.getShape(),
-        /*scalableDims=*/{false, scalableChDim}, rhs.getDefiningOp());
+    auto maybeMaskedRhs = maybeMaskXferOp(
+        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
 
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
-    auto maybeMaskedRES = maybeMaskXferOp(
-        resType.getShape(),
-        /*scalableDims=*/{false, false, scalableChDim}, res.getDefiningOp());
+    auto maybeMaskedRes = maybeMaskXferOp(
+        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
@@ -3133,7 +3132,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, maybeMaskedLHS->getResult(0),
+            loc, maybeMaskedLhs->getResult(0),
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             inOutSliceSizes, inOutStrides));
       }
@@ -3141,13 +3140,13 @@ struct Conv1DGenerator
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, maybeMaskedRHS->getResult(0),
+          loc, maybeMaskedRhs->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, maybeMaskedRES->getResult(0),
+          loc, maybeMaskedRes->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
           inOutStrides));
     }
@@ -3186,7 +3185,6 @@ struct Conv1DGenerator
     // Its possible we failed to create the Fma.
     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
       // Manually revert (in reverse order) to avoid leaving a bad IR state.
-      // TODO: Replace with maybeMasked
       for (auto &collection :
            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
         for (Value v : collection)
@@ -3197,8 +3195,8 @@ struct Conv1DGenerator
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      maybeMaskedRES = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], maybeMaskedRES->getResult(0),
+      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], maybeMaskedRes->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
     }
@@ -3208,10 +3206,9 @@ struct Conv1DGenerator
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
     Operation *resOut = rewriter.create<vector::TransferWriteOp>(
-        loc, maybeMaskedRES->getResult(0), resShaped,
+        loc, maybeMaskedRes->getResult(0), resShaped,
         ValueRange{zero, zero, zero});
-    return maybeMaskXferOp(resType.getShape(),
-                           /*scalableDims=*/{false, false, scalableChDim},
+    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                            resOut);
   }
 
@@ -3253,9 +3250,8 @@ struct Conv1DGenerator
     if (!lhs || !rhs)
       return nullptr;
 
-    if (isa<FloatType>(resTy.getElementType())) {
+    if (isa<FloatType>(resTy.getElementType()))
       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
-    }
 
     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
     return rewriter.create<arith::AddIOp>(loc, mul, res);
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index a1a52397386c97..4553dc4fbe414b 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -19,6 +19,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @depthwise_conv1d_nwc_wc_dyn_w_dim(%input: memref<3x?x4xf32>, %filter: memref<?x4xf32>, %output: memref<3x?x4xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x?x4xf32>, memref<?x4xf32>)
+    outs(%output : memref<3x?x4xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 2, 4, 2] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @depthwise_conv1d_nwc_wc_dyn_ch_dim(%input: memref<3x5x?xf32>, %filter: memref<2x?xf32>, %output: memref<3x2x?xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_nwc_wc



More information about the Mlir-commits mailing list