[Mlir-commits] [mlir] 641fe70 - [mlir][Linalg] Fix and improve vectorization of depthwise convolutions.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 15 05:03:15 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-15T12:58:05Z
New Revision: 641fe70776c8fb0eba5c28a4aaa4a222e184a11d

URL: https://github.com/llvm/llvm-project/commit/641fe70776c8fb0eba5c28a4aaa4a222e184a11d
DIFF: https://github.com/llvm/llvm-project/commit/641fe70776c8fb0eba5c28a4aaa4a222e184a11d.diff

LOG: [mlir][Linalg] Fix and improve vectorization of depthwise convolutions.

When trying to connect the vectorization of depthwise convolutions to e2e execution,
a number of problems surfaced.
Fix an off-by-one error on the size of the input vector (similarly to what was previously done for regular conv).
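For reference, here is the corrected input-width computation, worked through with the
stride-1 / dilation-1 example quoted in the patch comment (ow, kw, sw, dw stand in for
wSize, kwSize, strideW, dilationW):

    iw = ((ow - 1) * sw + 1) + ((kw - 1) * dw + 1) - 1
       e.g. ow = 14, kw = 3, sw = 1, dw = 1:
       iw = (13 + 1) + (2 + 1) - 1 = 14 + 3 - 1 = 16

The previous expression omitted the trailing "- 1" and thus read one extra element
(17 here; this is also why the input vector in the updated test shrinks from 3x5x4 to 3x4x4).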
Rewrite the lowering to vector.fma instead of vector.contract: the KW reduction dimension has already been unrolled, and vector.contract requires a reduction dimension to be valid.
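As a rough sketch (value names are hypothetical; the shapes mirror the
depthwise_conv1d_nwc_wc_3x5x4_memref test updated below), each unrolled (kw, w) slice is
now lowered to a broadcast of the filter slice followed by an fma:

    %b_filter = vector.broadcast %v_filter : vector<4xf32> to vector<3x2x4xf32>
    %fma      = vector.fma %v_input, %b_filter, %v_acc : vector<3x2x4xf32>

whereas the previous vector.contract marked the channel dimension c as a reduction even
though it also appears in the result, which is not a valid contraction once the kw loop
has been unrolled away.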

Differential Revision: https://reviews.llvm.org/D113884

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0b856c2e5678..002f791a63d5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1477,7 +1477,7 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
         {nSize,
          // iw = ow * sw + kw *  dw - 1
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
-         // Perform the proper inclusive -> exclusive -> inclusive
+         // Perform the proper inclusive -> exclusive -> inclusive.
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
         lhsEltType);
@@ -1557,9 +1557,8 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
-  vector::ContractionOp conv1dSliceAsContraction(OpBuilder &b, Location loc,
-                                                 Value lhs, Value rhs,
-                                                 Value res) {
+  Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
+                                 Value rhs, Value res) {
     StringRef par = Par().strRef, red = Red().strRef;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
@@ -1597,7 +1596,10 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
     VectorType lhsType = VectorType::get(
-        {nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1,
+        {nSize,
+         // iw = ow * sw + kw *  dw - 1
+         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
         lhsEltType);
     VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
@@ -1651,7 +1653,7 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = dilatedConv1dSliceAsContraction(
+        resVals[w] = dilatedConv1dSliceAsFma(
             builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
@@ -1675,17 +1677,11 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
         .getOperation();
   }
 
-  // Create a contraction: lhs{n, w, c} * rhs{c} -> res{n, w, c}
-  vector::ContractionOp dilatedConv1dSliceAsContraction(OpBuilder &b,
-                                                        Location loc, Value lhs,
-                                                        Value rhs, Value res) {
-    StringRef par = Par().strRef, red = Red().strRef;
-    AffineExpr n, w, c;
-    bindDims(ctx, n, w, c);
-    return builder.create<vector::ContractionOp>(
-        loc, lhs, rhs, res,
-        /*indexingMaps=*/MapList{{n, w, c}, {c}, {n, w, c}},
-        /*iteratorTypes=*/ArrayRef<StringRef>{par, par, red});
+  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
+  Value dilatedConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
+                                Value rhs, Value res) {
+    Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
+    return b.create<vector::FMAOp>(loc, lhs, bcast, res);
   }
 
   /// Entry point that transposes into the common form:

diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 381fcb9a9f44..e545e7f191aa 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -200,9 +200,6 @@ func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: m
   return
 }
 
-//       CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//       CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-
 //       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
 //  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
 
@@ -217,24 +214,19 @@ func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: m
 /// w == 0, kw == 0
 //      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
 //      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
-// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
 /// w == 0, kw == 1
 //      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
 //      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
-// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
 
-/// w == 0, kw == 0
-//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
-// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
-// CHECK-SAME:       iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
-// CHECK-SAME:     : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
-/// w == 0, kw == 1
-//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
-// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
-// CHECK-SAME:       iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
-// CHECK-SAME:     : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
+/// w == 0, kw = 0
+//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
+//      CHECK:  %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>
+
+/// w == 0, kw = 1
+//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
+//      CHECK:  %[[FMA_1:.*]] = vector.fma %[[V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x2x4xf32>
 
 // Write the result back in one shot.
-//      CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:   vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]


        

