[Mlir-commits] [mlir] 9c49717 - [mlir][Linalg] Refactor vectorization of conv1d more aggressively.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 3 01:29:39 PDT 2021
Author: Nicolas Vasilache
Date: 2021-11-03T08:18:01Z
New Revision: 9c4971740b875530d21d3b73d8843fa88249085f
URL: https://github.com/llvm/llvm-project/commit/9c4971740b875530d21d3b73d8843fa88249085f
DIFF: https://github.com/llvm/llvm-project/commit/9c4971740b875530d21d3b73d8843fa88249085f.diff
LOG: [mlir][Linalg] Refactor vectorization of conv1d more aggressively.
This better decouples the transfer read/write operations from the vector-only rewrite of the convolution.
This form is close to ready to plop into a new vector.conv op, with the vector.transfer operations to be generalized as part of generic vectorization once the properties of ConvolutionOpInterface are inferred from the indexing maps.
This also results in a nice perf boost in the dw == 1 cases.
Differential Revision: https://reviews.llvm.org/D112822
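For intuition, here is a rough scalar model of the new schedule (my own illustration, not code from the patch): the output is read once, accumulated across all kw taps in registers, and written back once, instead of the previous read-modify-write per (kw, w) pair. Shapes and loop structure mirror the second test case below (strideW = 3, dilationW = 2), with only the w dimension modeled.

  #include <cstdio>
  #include <vector>

  int main() {
    const int wSize = 2, kwSize = 2, strideW = 3, dilationW = 2;
    // Single "transfer_read" of each operand.
    std::vector<float> lhs((wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1, 1.0f);
    std::vector<float> rhs(kwSize, 1.0f);
    std::vector<float> res(wSize, 0.0f);

    for (int kw = 0; kw < kwSize; ++kw)   // unrolled along kw in the patch
      for (int w = 0; w < wSize; ++w)     // unrolled, or batched when strideW == 1
        res[w] += lhs[w * strideW + kw * dilationW] * rhs[kw];

    for (int w = 0; w < wSize; ++w)       // single "transfer_write" at the end
      printf("res[%d] = %f\n", w, res[w]);
    return 0;
  }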
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0678563cb2d3..c65d2a1de869 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1458,60 +1458,86 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
// When strideW == 1, we can batch the contiguous loads and avoid unrolling
int64_t wSizeStep = strideW == 1 ? wSize : 1;
- VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize},
- lhsShapedType.getElementType());
- VectorType rhsType =
- VectorType::get({cSize, fSize}, rhsShapedType.getElementType());
- VectorType resType = VectorType::get({nSize, wSizeStep, fSize},
- resShapedType.getElementType());
-
- SmallVector<Value> lhsVals, rhsVals, resVals;
+ Type lhsEltType = lhsShapedType.getElementType();
+ Type rhsEltType = rhsShapedType.getElementType();
+ Type resEltType = resShapedType.getElementType();
+ VectorType lhsType = VectorType::get(
+ {nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1,
+ cSize},
+ lhsEltType);
+ VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
+ VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
+
+ // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
+ Value lhs = builder.create<vector::TransferReadOp>(
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
+ Value rhs = builder.create<vector::TransferReadOp>(
+ loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+ // Read res slice of size {n, w, f} @ [0, 0, 0].
+ Value res = builder.create<vector::TransferReadOp>(
+ loc, resType, resShaped, ValueRange{zero, zero, zero});
+
+ //===------------------------------------------------------------------===//
+ // Begin vector-only rewrite part
+ //===------------------------------------------------------------------===//
// Unroll along kw and read slices of lhs and rhs.
- // Alternatively we could preload both 3-d slices and extract smaller slices
- // iteratively without touching memory. But this will quickly spill.
+ SmallVector<Value> lhsVals, rhsVals, resVals;
for (int64_t kw = 0; kw < kwSize; ++kw) {
- // Read rhs slice of size {c, f} @ [kw, 0, 0].
- Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
- rhsVals.push_back(builder.create<vector::TransferReadOp>(
- loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero}));
+ // Extract rhs slice of size {c, f} @ [kw].
+ rhsVals.push_back(builder.create<vector::ExtractOp>(
+ loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
- for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
- // Read lhs slice of size {n, wSizeStep, c}
+ for (int64_t w = 0; w < wSize; w += wSizeStep) {
+ // Extract lhs slice of size {n, wSizeStep, c}
// @ [0, sw * w + dw * kw, 0].
- Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
- loc, strideW * w_iv + dilationW * kw);
- lhsVals.push_back(builder.create<vector::TransferReadOp>(
- loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero}));
-
- // Read res slice: {n, wSizeStep, f} @ [0, w, 0].
- Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
- // When operating on tensors, reading from the updated value is required
- // for vector.transfer_read/write hoisting to function as expected.
- resVals.push_back(builder.create<vector::TransferReadOp>(
- loc, resType, resShaped, ValueRange{zero, wVal, zero}));
- }
- }
- for (int64_t kw = 0; kw < kwSize; ++kw) {
- for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
- // Compute contraction: I{n, w, c} * F{c, f} -> O{n, w, f}
- resVals[kw * (wSize / wSizeStep) + w_iv] = conv1dSliceAsContraction(
- builder, loc, lhsVals[kw * (wSize / wSizeStep) + w_iv], rhsVals[kw],
- resVals[kw * (wSize / wSizeStep) + w_iv]);
+ lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+ loc, lhs,
+ /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+
+ // This does not depend on kw.
+ if (kw == 0) {
+ // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
+ resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+ loc, res,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+ }
}
}
+
+ auto linearIndex = [&](int64_t kw, int64_t w) {
+ return kw * (wSize / wSizeStep) + w;
+ };
+
+ // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
for (int64_t kw = 0; kw < kwSize; ++kw) {
- for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
- Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
- // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
- write = builder.create<vector::TransferWriteOp>(
- loc, resVals[kw * (wSize / wSizeStep) + w_iv], resShaped,
- ValueRange{zero, wVal, zero});
- if (write.getNumResults() == 1)
- resShaped = write->getResult(0);
+ for (int64_t w = 0; w < wSize; w += wSizeStep) {
+ resVals[w] = conv1dSliceAsContraction(
+ builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
}
}
- return write.getOperation();
+ // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
+ // This does not depend on kw.
+ for (int64_t w = 0; w < wSize; w += wSizeStep) {
+ res = builder.create<vector::InsertStridedSliceOp>(
+ loc, resVals[w], res,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+ /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+ }
+ //===------------------------------------------------------------------===//
+ // End vector-only rewrite part
+ //===------------------------------------------------------------------===//
+
+ // Write back res slice of size {n, w, f} @ [0, 0, 0].
+ return builder
+ .create<vector::TransferWriteOp>(loc, res, resShaped,
+ ValueRange{zero, zero, zero})
+ .getOperation();
}
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
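The FileCheck updates below follow from two quantities in the code above: the width of the single lhs read, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1, and the per-slice offset w * strideW + kw * dilationW. As a sketch only (a hypothetical standalone helper, not part of the patch), the three test configurations reproduce the widths (5, 7, 5) and the slice offsets checked below:

  #include <cstdint>
  #include <cstdio>

  static void printConv1dSlices(int64_t wSize, int64_t kwSize, int64_t strideW,
                                int64_t dilationW) {
    // Second dimension of lhsType, i.e. the width of the one-shot lhs read.
    int64_t lhsW = (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1;
    int64_t wSizeStep = strideW == 1 ? wSize : 1;
    printf("lhs read width = %lld\n", (long long)lhsW);
    for (int64_t kw = 0; kw < kwSize; ++kw)
      for (int64_t w = 0; w < wSize; w += wSizeStep)
        printf("  kw=%lld w=%lld -> lhs offset %lld\n", (long long)kw,
               (long long)w, (long long)(w * strideW + kw * dilationW));
  }

  int main() {
    printConv1dSlices(2, 1, 3, 1); // width 5; offsets 0, 3
    printConv1dSlices(2, 2, 3, 2); // width 7; offsets 0, 3, 2, 5
    printConv1dSlices(2, 2, 1, 2); // width 5; offsets 0, 2
    return 0;
  }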
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 31be54ec219c..0a1cbc41d58e 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -16,35 +16,48 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
/// w == 0, kw == 0
-// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x5x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
/// w == 1, kw == 0
-// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
-// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x5x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
/// w == 0, kw == 0
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+
/// w == 1, kw == 0
-// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
/// w == 0, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
/// w == 1, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK: %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
+// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// -----
@@ -64,104 +77,115 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+
/// w == 0, kw == 0
-// CHECK: %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x7x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+/// w == 1, kw == 0
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x7x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+
/// w == 0, kw == 1
-// CHECK: %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x7x3xf32> to vector<4x1x3xf32>
/// w == 1, kw == 0
-// CHECK: %[[V_FILTER_B:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_INPUT0_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-/// w == 1, kw == 1
-// CHECK: %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x7x3xf32> to vector<4x1x3xf32>
/// w == 0, kw == 0
-// CHECK: %[[CONTRACT0_A:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]]
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-/// w == 0, kw == 1
-// CHECK: %[[CONTRACT1_A:.+]] = vector.contract {
+/// w == 1, kw == 0
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]]
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-/// w == 1, kw == 0
-// CHECK: %[[CONTRACT0_B:.+]] = vector.contract {
+/// w == 0, kw == 1
+// CHECK: %[[CONTRACT_2:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]]
+// CHECK-SAME: %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
/// w == 1, kw == 1
-// CHECK: %[[CONTRACT1_B:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_3:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]]
+// CHECK-SAME: %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
/// w == 0, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 0, kw == 1
-// CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_2]], %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
/// w == 1, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 1, kw == 1
-// CHECK: vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK: %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_3]], %[[RES_0]]
+// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// -----
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+ linalg.conv_1d_nwc_wcf
+ {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
+ outs(%output : memref<4x2x8xf32>)
+ return
+}
+
// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: func @conv1d_nwc_4x2x8_memref
// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
-func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
/// w == 0, kw == 0
-// CHECK: %[[V_FILTER_000:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
-// CHECK: %[[V_INPUT_000:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
-// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x5x3xf32> to vector<4x2x3xf32>
/// w == 0, kw == 1
-// CHECK: %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
-// CHECK: %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
-// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x5x3xf32> to vector<4x2x3xf32>
/// w == 0, kw == 0
-// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT_000]], %[[V_FILTER_000]], %[[V_OUTPUT_0]]
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
-/// w == 0, kw == 1
-// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+/// w == 0, kw == 1
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT_020]], %[[V_FILTER_100]], %[[V_OUTPUT_1]]
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
-/// w == 0, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 0, kw == 1
-// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
- linalg.conv_1d_nwc_wcf
- {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
- ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
- outs(%output : memref<4x2x8xf32>)
- return
-}
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]