[Mlir-commits] [mlir] 99ff697 - [mlir][Vector] Add support for 1D depthwise conv vectorization

Nicolas Vasilache llvmlistbot at llvm.org
Fri Nov 12 05:15:46 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-12T13:14:09Z
New Revision: 99ff697bf72af978515ecca833337965502d4e63

URL: https://github.com/llvm/llvm-project/commit/99ff697bf72af978515ecca833337965502d4e63
DIFF: https://github.com/llvm/llvm-project/commit/99ff697bf72af978515ecca833337965502d4e63.diff

LOG: [mlir][Vector] Add support for 1D depthwise conv vectorization

At this time the 2 flavors of conv are a little too different to allow significant code sharing, and others will likely come up,
so we go the easy route first by duplicating and adapting.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D113758

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c65d2a1de869..9f2b798c12f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1390,16 +1390,25 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
 // Convolution vectorization patterns
 //===----------------------------------------------------------------------===//
 namespace {
-/// Generate a vector implementation for:
+/// Generate a vector implementation for either:
 /// ```
 ///   Op def: (     n,     w,     c,    kw,    f  )
 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
 /// ```
 /// kw is unrolled, w is unrolled iff dilationW > 1.
-struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
-  Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
-                           int dilationW)
+///
+/// or
+///
+/// ```
+///   Op def: (     n,     w,     c,    kw )
+///    Iters: ({Par(), Par(), Par(), Red()})
+///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+/// ```
+/// kw is unrolled, w is unrolled iff strideW > 1.
+struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
+  Conv1D_NWC_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+                       int dilationW)
       : StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
         strideW(strideW), dilationW(dilationW) {
     // Determine whether `linalgOp` can be generated with this generator
@@ -1413,7 +1422,8 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
-    if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 ||
+    if (lhsShapedType.getRank() != 3 ||
+        (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
         resShapedType.getRank() != 3)
       return;
 
@@ -1553,12 +1563,130 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
         /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
   }
 
+  /// Generate a vector implementation for:
+  /// ```
+  ///   Op def: (     n,     w,     c,    kw)
+  ///    Iters: ({Par(), Par(), Par(), Red()})
+  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  /// ```
+  /// kw is always unrolled; w is unrolled iff strideW > 1. Returns failure()
+  /// when `valid` is false, otherwise the final vector.transfer_write op.
+  FailureOr<Operation *> dilated_conv() {
+    if (!valid)
+      return failure();
+
+    int nSize = lhsShapedType.getShape()[0];
+    int wSize = resShapedType.getShape()[1];
+    int cSize = lhsShapedType.getShape()[2];
+    int kwSize = rhsShapedType.getShape()[0];
+
+    vector::TransferWriteOp write; // NOTE(review): unused local.
+    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+
+    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
+    // When strideW == 1, all of w is contiguous and handled as one slice.
+    int64_t wSizeStep = strideW == 1 ? wSize : 1;
+
+    Type lhsEltType = lhsShapedType.getElementType();
+    Type rhsEltType = rhsShapedType.getElementType();
+    Type resEltType = resShapedType.getElementType();
+    VectorType lhsType = VectorType::get(
+        {nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1,
+         cSize},
+        lhsEltType);
+    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+
+    // Read lhs {n, (w-1)*sw + (kw-1)*dw + 2, c} @ [0, 0, 0]: 1 more than used.
+    Value lhs = builder.create<vector::TransferReadOp>(
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    // Read rhs slice of size {kw, c} @ [0, 0].
+    Value rhs = builder.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                       ValueRange{zero, zero});
+    // Read res slice of size {n, w, c} @ [0, 0, 0].
+    Value res = builder.create<vector::TransferReadOp>(
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
+
+    //===------------------------------------------------------------------===//
+    // Begin vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Unroll along kw (and w iff strideW > 1) and extract operand slices.
+    SmallVector<Value> lhsVals, rhsVals, resVals;
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      // Extract rhs slice of size {c} @ [kw].
+      rhsVals.push_back(builder.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        // Extract lhs slice of size {n, wSizeStep, c}
+        //   @ [0, sw * w + dw * kw, 0].
+        lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+            loc, lhs,
+            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+
+        // res slices do not depend on kw: extract them only when kw == 0.
+        if (kw == 0) {
+          // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
+          resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+              loc, res,
+              /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+              /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+        }
+      }
+    }
+
+    auto linearIndex = [&](int64_t kw, int64_t w) {
+      return kw * (wSize / wSizeStep) + w;
+    }; // lhsVals is kw-major; w == slice index as wSizeStep is 1 or wSize.
+
+    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        resVals[w] = dilatedConv1dSliceAsContraction(
+            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+      }
+    }
+
+    // Insert each res slice {n, wSizeStep, c} back @ [0, w, 0], once all
+    // kw contributions have been accumulated into it.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = builder.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+    }
+    //===------------------------------------------------------------------===//
+    // End vector-only rewrite part
+    //===------------------------------------------------------------------===//
+
+    // Write back res slice of size {n, w, c} @ [0, 0, 0].
+    return builder
+        .create<vector::TransferWriteOp>(loc, res, resShaped,
+                                         ValueRange{zero, zero, zero})
+        .getOperation();
+  }
+
+  /// Contraction lhs{n, w, c} * rhs{c} -> res{n, w, c}, built with `b`.
+  vector::ContractionOp dilatedConv1dSliceAsContraction(OpBuilder &b,
+                                                        Location loc, Value lhs,
+                                                        Value rhs, Value res) {
+    StringRef par = Par().strRef, red = Red().strRef;
+    AffineExpr n, w, c;
+    bindDims(ctx, n, w, c); // NOTE(review): c is Red() but also indexes res.
+    return b.create<vector::ContractionOp>(
+        loc, lhs, rhs, res,
+        /*indexingMaps=*/MapList{{n, w, c}, {c}, {n, w, c}},
+        /*iteratorTypes=*/ArrayRef<StringRef>{par, par, red});
+  }
+
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   FailureOr<Operation *> generateConv() {
     AffineExpr n, w, f, kw, c;
     bindDims(ctx, n, w, f, kw, c);
-
     if (!iters({Par(), Par(), Par(), Red(), Red()}))
       return failure();
 
@@ -1570,6 +1698,22 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
     return failure();
   }
 
+  /// Entry point for the {n, w, c, kw} iteration form. No transposition is
+  /// performed: only this layout is accepted, anything else fails:
+  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  FailureOr<Operation *> generateDilatedConv() {
+    AffineExpr n, w, c, kw;
+    bindDims(ctx, n, w, c, kw);
+    // Bail out unless both the iterator kinds and the indexing maps match;
+    // `layout` is only queried once the iterators are known to match.
+    if (!iters({Par(), Par(), Par(), Red()}) ||
+        !layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+                 /*rhsIndex*/ {kw, c},
+                 /*resIndex*/ {n, w, c}}))
+      return failure();
+    return dilated_conv();
+  }
+
 private:
   bool valid;
   int strideW, dilationW;
@@ -1588,8 +1732,11 @@ vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
   LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
-  Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation);
-  return e.generateConv();
+  Conv1D_NWC_Generator e(b, linalgOp, stride, dilation);
+  auto res = e.generateConv();
+  if (succeeded(res))
+    return res;
+  return e.generateDilatedConv();
 }
 
 struct VectorizeConvolution

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 0a1cbc41d58e..aa3d3f55953c 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -180,7 +180,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
 // CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
-/// w == 1, kw == 1
+/// w == 0, kw == 1
 //      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
@@ -189,3 +189,52 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 
 // Write the result back in one shot.
 //      CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
+func @depthwise_conv1d_nwc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+  linalg.depthwise_conv1D_nw  // kw = 2, dilation = 2, stride = 1: w is not unrolled.
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+    outs(%output : memref<3x2x4xf32>)
+  return
+}
+
+//       CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//       CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+
+//       CHECK: func @depthwise_conv1d_nwc_3x5x4_memref
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
+
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+//      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 0, kw == 0
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
+/// w == 0, kw == 1
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
+
+/// w == 0, kw == 0
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME:     : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
+/// w == 0, kw == 1
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
+// CHECK-SAME:     : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
+
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]


        


More information about the Mlir-commits mailing list