[Mlir-commits] [mlir] 7b09f15 - [mlir][Linalg] Refactor conv vectorization to decouple memory from vector ops.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 3 01:09:39 PDT 2021
Author: Nicolas Vasilache
Date: 2021-11-03T08:03:40Z
New Revision: 7b09f157e1748597858bbeb392c6c95debb7ba0f
URL: https://github.com/llvm/llvm-project/commit/7b09f157e1748597858bbeb392c6c95debb7ba0f
DIFF: https://github.com/llvm/llvm-project/commit/7b09f157e1748597858bbeb392c6c95debb7ba0f.diff
LOG: [mlir][Linalg] Refactor conv vectorization to decouple memory from vector ops.
This refactoring prepares conv1d vectorization for a future integration into
the generic codegen path.
Once transfer_read / transfer_write vectorization also supports sliding windows,
the special pattern for conv can disappear.
This will also likely need a vector.conv operation.
Differential Revision: https://reviews.llvm.org/D112797
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b7520f1a62fa3..0678563cb2d3f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1396,8 +1396,7 @@ namespace {
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
/// ```
-/// w and kw are unrolled.
-/// TODO: do not unroll w (resp. kw) when the strideW ( resp. dilationW) is > 1.
+/// kw is unrolled, w is unrolled iff strideW > 1.
struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
int dilationW)
@@ -1455,52 +1454,58 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
vector::TransferWriteOp write;
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
+ // When strideW == 1, we can batch the contiguous loads and avoid unrolling.
int64_t wSizeStep = strideW == 1 ? wSize : 1;
+ VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize},
+ lhsShapedType.getElementType());
+ VectorType rhsType =
+ VectorType::get({cSize, fSize}, rhsShapedType.getElementType());
+ VectorType resType = VectorType::get({nSize, wSizeStep, fSize},
+ resShapedType.getElementType());
+
+ SmallVector<Value> lhsVals, rhsVals, resVals;
// Unroll along kw and read slices of lhs and rhs.
// Alternatively we could preload both 3-d slices and extract smaller slices
// iteratively without touching memory. But this will quickly spill.
for (int64_t kw = 0; kw < kwSize; ++kw) {
// Read rhs slice of size {c, f} @ [kw, 0, 0].
Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
- VectorType rhsType =
- VectorType::get({cSize, fSize}, rhsShapedType.getElementType());
- Value rhs = builder.create<vector::TransferReadOp>(
- loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero});
+ rhsVals.push_back(builder.create<vector::TransferReadOp>(
+ loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero}));
for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
// Read lhs slice of size {n, wSizeStep, c}
// @ [0, sw * w + dw * kw, 0].
Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
loc, strideW * w_iv + dilationW * kw);
- VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize},
- lhsShapedType.getElementType());
- Value lhs = builder.create<vector::TransferReadOp>(
- loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero});
+ lhsVals.push_back(builder.create<vector::TransferReadOp>(
+ loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero}));
// Read res slice: {n, wSizeStep, f} @ [0, w, 0].
Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
- VectorType resType = VectorType::get({nSize, wSizeStep, fSize},
- resShapedType.getElementType());
// When operating on tensors, reading from the updated value is required
// for vector.transfer_read/write hoisting to function as expected.
- Value res = builder.create<vector::TransferReadOp>(
- loc, resType, resShaped, ValueRange{zero, wVal, zero});
-
+ resVals.push_back(builder.create<vector::TransferReadOp>(
+ loc, resType, resShaped, ValueRange{zero, wVal, zero}));
+ }
+ }
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
// Compute contraction: I{n, w, c} * F{c, f} -> O{n, w, f}
- StringRef par = Par().strRef, red = Red().strRef;
- AffineExpr n, w, f, c;
- bindDims(ctx, n, w, f, c);
- // clang-format off
- res = builder.create<vector::ContractionOp>(
- loc, lhs, rhs, res,
- /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
- /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
- // clang-format on
-
+ resVals[kw * (wSize / wSizeStep) + w_iv] = conv1dSliceAsContraction(
+ builder, loc, lhsVals[kw * (wSize / wSizeStep) + w_iv], rhsVals[kw],
+ resVals[kw * (wSize / wSizeStep) + w_iv]);
+ }
+ }
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
+ Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
// Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
write = builder.create<vector::TransferWriteOp>(
- loc, res, resShaped, ValueRange{zero, wVal, zero});
+ loc, resVals[kw * (wSize / wSizeStep) + w_iv], resShaped,
+ ValueRange{zero, wVal, zero});
if (write.getNumResults() == 1)
resShaped = write->getResult(0);
}
@@ -1509,6 +1514,19 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
return write.getOperation();
}
+ // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
+ vector::ContractionOp conv1dSliceAsContraction(OpBuilder &b, Location loc,
+ Value lhs, Value rhs,
+ Value res) {
+ StringRef par = Par().strRef, red = Red().strRef;
+ AffineExpr n, w, f, c;
+ bindDims(ctx, n, w, f, c);
+ return builder.create<vector::ContractionOp>(
+ loc, lhs, rhs, res,
+ /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
+ /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
+ }
+
/// Entry point that transposes into the common form:
/// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
FailureOr<Operation *> generateConv() {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index b1802fded0b13..31be54ec219c5 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -24,21 +24,26 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK: %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// w == 1, kw == 0
+// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+
+/// w == 0, kw == 0
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-
/// w == 1, kw == 0
-// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
// CHECK: %[[CONTRACT1:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+
+/// w == 0, kw == 0
+// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+/// w == 1, kw == 0
// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
// -----
@@ -69,48 +74,53 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// CHECK: %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK: %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK: %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// w == 0, kw == 1
+// CHECK: %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+/// w == 1, kw == 0
+// CHECK: %[[V_FILTER_B:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT0_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// w == 1, kw == 1
+// CHECK: %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+
+/// w == 0, kw == 0
// CHECK: %[[CONTRACT0_A:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-// CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-
/// w == 0, kw == 1
-// CHECK: %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
// CHECK: %[[CONTRACT1_A:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-// CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
-
/// w == 1, kw == 0
-// CHECK: %[[V_FILTER_B:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_INPUT0_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK: %[[CONTRACT0_B:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-// CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-
/// w == 1, kw == 1
-// CHECK: %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
// CHECK: %[[CONTRACT1_B:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]]
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+
+/// w == 0, kw == 0
+// CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+/// w == 0, kw == 1
+// CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+/// w == 1, kw == 0
+// CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+/// w == 1, kw == 1
// CHECK: vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
// -----
-
-
// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -127,22 +137,27 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// CHECK: %[[V_FILTER_000:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
// CHECK: %[[V_INPUT_000:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+/// w == 0, kw == 1
+// CHECK: %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
+// CHECK: %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+
+/// w == 0, kw == 0
// CHECK: %[[CONTRACT0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT_000]], %[[V_FILTER_000]], %[[V_OUTPUT_0]]
// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
-// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-
/// w == 0, kw == 1
-// CHECK: %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
-// CHECK: %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
-// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
// CHECK: %[[CONTRACT1:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT_020]], %[[V_FILTER_100]], %[[V_OUTPUT_1]]
// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
+
+/// w == 0, kw == 0
+// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+/// w == 0, kw == 1
// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
linalg.conv_1d_nwc_wcf
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
More information about the Mlir-commits
mailing list