[Mlir-commits] [mlir] 203accf - [mlir][Linalg] Improve conv vectorization for the stride==1 case.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Oct 21 08:18:41 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-21T15:18:28Z
New Revision: 203accf0bdde1d276646c79dfa605ee3426f1ca8

URL: https://github.com/llvm/llvm-project/commit/203accf0bdde1d276646c79dfa605ee3426f1ca8
DIFF: https://github.com/llvm/llvm-project/commit/203accf0bdde1d276646c79dfa605ee3426f1ca8.diff

LOG: [mlir][Linalg] Improve conv vectorization for the stride==1 case.

In the stride == 1 case, conv1d reads contiguous data along the input dimension. This can be advantageously used to perform bulk memory transfers and compute while avoiding unrolling along w. Experimentally, this can yield speedups of up to 50%.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D112139

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1b05cd6c378d..bff12e6f5c86 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1505,7 +1505,7 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
   ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   /// ```
-  /// w and kw are unrolled.
+  /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
   FailureOr<Operation *> conv() {
     if (!valid)
@@ -1520,47 +1520,50 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
     vector::TransferWriteOp write;
     Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
 
+    int64_t wSizeStep = strideW == 1 ? wSize : 1;
+
     // Unroll along kw and read slices of lhs and rhs.
     // Alternatively we could preload both 3-d slices and extract smaller slices
     // iteratively without touching memory. But this will quickly spill.
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      // Read rhs slice of size {1, c, f} @ [kw, 0, 0].
+      // Read rhs slice of size {c, f} @ [kw, 0, 0].
       Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
       VectorType rhsType =
-          VectorType::get({1, cSize, fSize}, rhsShapedType.getElementType());
+          VectorType::get({cSize, fSize}, rhsShapedType.getElementType());
       Value rhs = builder.create<vector::TransferReadOp>(
           loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero});
 
-      for (int64_t w = 0; w < wSize; ++w) {
-        // Read lhs slice of size {n, 1, c} @ [0, sw * w + dw * kw, 0].
+      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
+        // Read lhs slice of size {n, wSizeStep, c}
+        //   @ [0, sw * w + dw * kw, 0].
         Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
-            loc, strideW * w + dilationW * kw);
-        VectorType lhsType =
-            VectorType::get({nSize, 1, cSize}, lhsShapedType.getElementType());
+            loc, strideW * w_iv + dilationW * kw);
+        VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize},
+                                             lhsShapedType.getElementType());
         Value lhs = builder.create<vector::TransferReadOp>(
             loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero});
 
-        // Read res slice: {n, 1, f} @ [0, w, 0].
-        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w);
-        VectorType resType =
-            VectorType::get({nSize, 1, fSize}, resShapedType.getElementType());
+        // Read res slice: {n, wSizeStep, f} @ [0, w, 0].
+        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
+        VectorType resType = VectorType::get({nSize, wSizeStep, fSize},
+                                             resShapedType.getElementType());
         // When operating on tensors, reading from the updated value is required
         // for vector.transfer_read/write hoisting to function as expected.
         Value res = builder.create<vector::TransferReadOp>(
             loc, resType, resShaped, ValueRange{zero, wVal, zero});
 
-        // Compute contraction: I{n, 1, c} * F{1, c, f} -> O{n, 1, f}
+        // Compute contraction: I{n, w, c} * F{c, f} -> O{n, w, f}
         StringRef par = Par().strRef, red = Red().strRef;
-        AffineExpr n, one, f, c;
-        bindDims(ctx, n, one, f, c);
+        AffineExpr n, w, f, c;
+        bindDims(ctx, n, w, f, c);
         // clang-format off
         res = builder.create<vector::ContractionOp>(
           loc, lhs, rhs, res,
-          /*indexingMaps=*/MapList{{n, one, c}, {one, c, f}, {n, one, f}},
+          /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
           /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
         // clang-format on
 
-        // Write back res slice: {n, 1, f} @ [0, w, 0].
+        // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
         write = builder.create<vector::TransferWriteOp>(
             loc, res, resShaped, ValueRange{zero, wVal, zero});
         if (write.getNumResults() == 1)

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index a7c6f47cda7f..b1802fded0b1 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -9,7 +9,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
 }
 
 // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 
 //      CHECK: func @conv1d_nwc_4x2x8_memref
@@ -28,7 +28,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
-// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 //      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 /// w == 1, kw == 0
@@ -38,7 +38,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]]
-// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 //      CHECK:   vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
 
 // -----
@@ -52,7 +52,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 }
 
 // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 
 //      CHECK: func @conv1d_nwc_4x2x8_memref
@@ -73,7 +73,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]]
-// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 //      CHECK:   vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 /// w == 0, kw == 1
@@ -83,7 +83,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]]
-// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 //      CHECK:   vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
 
 /// w == 1, kw == 0
@@ -94,7 +94,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]]
-// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 //      CHECK:   vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 /// w == 1, kw == 1
@@ -104,5 +104,49 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]]
-// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 //      CHECK:   vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+
+// -----
+
+
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func @conv1d_nwc_4x2x8_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// w == 0, kw == 0
+//      CHECK:   %[[V_FILTER_000:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
+//      CHECK:   %[[V_INPUT_000:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
+//      CHECK:   %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+//      CHECK:   %[[CONTRACT0:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT_000]], %[[V_FILTER_000]], %[[V_OUTPUT_0]]
+// CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
+//      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 0, kw == 1
+//      CHECK:   %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
+//      CHECK:   %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
+//      CHECK:   %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+//      CHECK:   %[[CONTRACT1:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT_020]], %[[V_FILTER_100]], %[[V_OUTPUT_1]]
+// CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
+//      CHECK:   vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list