[Mlir-commits] [mlir] 9c49717 - [mlir][Linalg] Refactor vectorization of conv1d more aggressively.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 3 01:29:39 PDT 2021


Author: Nicolas Vasilache
Date: 2021-11-03T08:18:01Z
New Revision: 9c4971740b875530d21d3b73d8843fa88249085f

URL: https://github.com/llvm/llvm-project/commit/9c4971740b875530d21d3b73d8843fa88249085f
DIFF: https://github.com/llvm/llvm-project/commit/9c4971740b875530d21d3b73d8843fa88249085f.diff

LOG: [mlir][Linalg] Refactor vectorization of conv1d more aggressively.

This better decouples the vector.transfer reads/writes from the vector-only rewrite of the convolution.
This form is nearly ready to be moved into a new vector.conv op, with the vector.transfer operations generalized as part of generic vectorization once the properties of ConvolutionOpInterface can be inferred from the indexing maps.

This also results in a nice perf boost in the dw == 1 cases.

Differential revision: https://reviews.llvm.org/D112822

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0678563cb2d3..c65d2a1de869 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1458,60 +1458,86 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
     // When strideW == 1, we can batch the contiguous loads and avoid unrolling
     int64_t wSizeStep = strideW == 1 ? wSize : 1;
 
-    VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize},
-                                         lhsShapedType.getElementType());
-    VectorType rhsType =
-        VectorType::get({cSize, fSize}, rhsShapedType.getElementType());
-    VectorType resType = VectorType::get({nSize, wSizeStep, fSize},
-                                         resShapedType.getElementType());
-
-    SmallVector<Value> lhsVals, rhsVals, resVals;
+    Type lhsEltType = lhsShapedType.getElementType();
+    Type rhsEltType = rhsShapedType.getElementType();
+    Type resEltType = resShapedType.getElementType();
+    VectorType lhsType = VectorType::get(
+        {nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1,
+         cSize},
+        lhsEltType);
+    VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
+    VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
+
+    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
+    Value lhs = builder.create<vector::TransferReadOp>(
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
+    Value rhs = builder.create<vector::TransferReadOp>(
+        loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+    // Read res slice of size {n, w, f} @ [0, 0, 0].
+    Value res = builder.create<vector::TransferReadOp>(
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
+
+    //===------------------------------------------------------------------===//
+    // Begin vector-only rewrite part
+    //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
-    // Alternatively we could preload both 3-d slices and extract smaller slices
-    // iteratively without touching memory. But this will quickly spill.
+    SmallVector<Value> lhsVals, rhsVals, resVals;
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      // Read rhs slice of size {c, f} @ [kw, 0, 0].
-      Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
-      rhsVals.push_back(builder.create<vector::TransferReadOp>(
-          loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero}));
+      // Extract rhs slice of size {c, f} @ [kw].
+      rhsVals.push_back(builder.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
 
-      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
-        // Read lhs slice of size {n, wSizeStep, c}
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        // Extract lhs slice of size {n, wSizeStep, c}
         //   @ [0, sw * w + dw * kw, 0].
-        Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
-            loc, strideW * w_iv + dilationW * kw);
-        lhsVals.push_back(builder.create<vector::TransferReadOp>(
-            loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero}));
-
-        // Read res slice: {n, wSizeStep, f} @ [0, w, 0].
-        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
-        // When operating on tensors, reading from the updated value is required
-        // for vector.transfer_read/write hoisting to function as expected.
-        resVals.push_back(builder.create<vector::TransferReadOp>(
-            loc, resType, resShaped, ValueRange{zero, wVal, zero}));
-      }
-    }
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
-        // Compute contraction: I{n, w, c} * F{c, f} -> O{n, w, f}
-        resVals[kw * (wSize / wSizeStep) + w_iv] = conv1dSliceAsContraction(
-            builder, loc, lhsVals[kw * (wSize / wSizeStep) + w_iv], rhsVals[kw],
-            resVals[kw * (wSize / wSizeStep) + w_iv]);
+        lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+            loc, lhs,
+            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+
+        // This does not depend on kw.
+        if (kw == 0) {
+          // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
+          resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+              loc, res,
+              /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
+              /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+        }
       }
     }
+
+    auto linearIndex = [&](int64_t kw, int64_t w) {
+      return kw * (wSize / wSizeStep) + w;
+    };
+
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
-        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
-        // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
-        write = builder.create<vector::TransferWriteOp>(
-            loc, resVals[kw * (wSize / wSizeStep) + w_iv], resShaped,
-            ValueRange{zero, wVal, zero});
-        if (write.getNumResults() == 1)
-          resShaped = write->getResult(0);
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        resVals[w] = conv1dSliceAsContraction(
+            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
 
-    return write.getOperation();
+    // Insert res slice {n, wSizeStep, f} @ [0, w, 0] back into the res vector.
+    // This does not depend on kw.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = builder.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+    }
+    //===------------------------------------------------------------------===//
+    // End vector-only rewrite part
+    //===------------------------------------------------------------------===//
+
+    // Write back res slice of size {n, w, f} @ [0, 0, 0].
+    return builder
+        .create<vector::TransferWriteOp>(loc, res, resShaped,
+                                         ValueRange{zero, zero, zero})
+        .getOperation();
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 31be54ec219c..0a1cbc41d58e 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -16,35 +16,48 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
 // CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
 
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+//      CHECK:    %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
 /// w == 0, kw == 0
-//      CHECK:   %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x5x3xf32> to vector<4x1x3xf32>
+//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
 /// w == 1, kw == 0
-//      CHECK:   %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[CONTRACT0:.+]] = vector.contract {
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x5x3xf32> to vector<4x1x3xf32>
+//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
 
 /// w == 0, kw == 0
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+
 /// w == 1, kw == 0
-//      CHECK:   %[[CONTRACT1:.+]] = vector.contract {
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
 /// w == 0, kw == 0
-//      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
 /// w == 1, kw == 0
-//      CHECK:   vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
+// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 // -----
 
@@ -64,104 +77,115 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 // CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
 
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
-//  CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
 //  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+
 /// w == 0, kw == 0
-//      CHECK:   %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x7x3xf32> to vector<4x1x3xf32>
+//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+/// w == 1, kw == 0
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x7x3xf32> to vector<4x1x3xf32>
+//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+
 /// w == 0, kw == 1
-//      CHECK:   %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+//      CHECK:   %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x7x3xf32> to vector<4x1x3xf32>
 /// w == 1, kw == 0
-//      CHECK:   %[[V_FILTER_B:.+]]   = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_INPUT0_B:.+]]   = vector.transfer_read  %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-/// w == 1, kw == 1
-//      CHECK:     %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x7x3xf32> to vector<4x1x3xf32>
 
 /// w == 0, kw == 0
-//      CHECK:   %[[CONTRACT0_A:.+]] = vector.contract {
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]]
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-/// w == 0, kw == 1
-//      CHECK:   %[[CONTRACT1_A:.+]] = vector.contract {
+/// w == 1, kw == 0
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]]
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-/// w == 1, kw == 0
-//      CHECK:   %[[CONTRACT0_B:.+]] = vector.contract {
+/// w == 0, kw == 1
+//      CHECK:   %[[CONTRACT_2:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]]
+// CHECK-SAME:     %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 /// w == 1, kw == 1
-//      CHECK:   %[[CONTRACT1_B:.+]] = vector.contract {
+//      CHECK:   %[[CONTRACT_3:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]]
+// CHECK-SAME:     %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
 /// w == 0, kw == 0
-//      CHECK:   vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 0, kw == 1
-//      CHECK:   vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_2]], %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
 /// w == 1, kw == 0
-//      CHECK:   vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 1, kw == 1
-//      CHECK:   vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_3]], %[[RES_0]]
+// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 // -----
 
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>)
+  return
+}
+
 // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 
 //      CHECK: func @conv1d_nwc_4x2x8_memref
 // CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
-func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
 /// w == 0, kw == 0
-//      CHECK:   %[[V_FILTER_000:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
-//      CHECK:   %[[V_INPUT_000:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
-//      CHECK:   %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x5x3xf32> to vector<4x2x3xf32>
 /// w == 0, kw == 1
-//      CHECK:   %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
-//      CHECK:   %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
-//      CHECK:   %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x5x3xf32> to vector<4x2x3xf32>
 
 /// w == 0, kw == 0
-//      CHECK:   %[[CONTRACT0:.+]] = vector.contract {
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT_000]], %[[V_FILTER_000]], %[[V_OUTPUT_0]]
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
 // CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
-/// w == 0, kw == 1
-//      CHECK:   %[[CONTRACT1:.+]] = vector.contract {
+/// w == 0, kw == 1
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     %[[V_INPUT_020]], %[[V_FILTER_100]], %[[V_OUTPUT_1]]
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
 // CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
 
-/// w == 0, kw == 0
-//      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 0, kw == 1
-//      CHECK:   vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-  linalg.conv_1d_nwc_wcf
-    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
-    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
-    outs(%output : memref<4x2x8xf32>)
-  return
-}
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]


        


More information about the Mlir-commits mailing list