[Mlir-commits] [mlir] 99ff697 - [mlir][Vector] Add support for 1D depthwise conv vectorization
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Nov 12 05:15:46 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-12T13:14:09Z
New Revision: 99ff697bf72af978515ecca833337965502d4e63
URL: https://github.com/llvm/llvm-project/commit/99ff697bf72af978515ecca833337965502d4e63
DIFF: https://github.com/llvm/llvm-project/commit/99ff697bf72af978515ecca833337965502d4e63.diff
LOG: [mlir][Vector] Add support for 1D depthwise conv vectorization
At this time the 2 flavors of conv are a little too different to allow significant code sharing, and others will likely come up,
so we go the easy route first by duplicating and adapting.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D113758
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c65d2a1de869..9f2b798c12f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1390,16 +1390,25 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
// Convolution vectorization patterns
//===----------------------------------------------------------------------===//
namespace {
-/// Generate a vector implementation for:
+/// Generate a vector implementation for either:
/// ```
/// Op def: ( n, w, c, kw, f )
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
/// ```
/// kw is unrolled, w is unrolled iff dilationW > 1.
-struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
- Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
- int dilationW)
+///
+/// or
+///
+/// ```
+/// Op def: ( n, w, c, kw )
+/// Iters: ({Par(), Par(), Par(), Red()})
+/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+/// ```
+/// kw is unrolled, w is unrolled iff dilationW > 1.
+struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
+ Conv1D_NWC_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+ int dilationW)
: StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
strideW(strideW), dilationW(dilationW) {
// Determine whether `linalgOp` can be generated with this generator
@@ -1413,7 +1422,8 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
resShapedType = resShaped.getType().dyn_cast<ShapedType>();
if (!lhsShapedType || !rhsShapedType || !resShapedType)
return;
- if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 ||
+ if (lhsShapedType.getRank() != 3 ||
+ (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
resShapedType.getRank() != 3)
return;
@@ -1553,12 +1563,130 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
/*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
}
+ /// Generate a vector implementation for:
+ /// ```
+ /// Op def: ( n, w, c, kw)
+ /// Iters: ({Par(), Par(), Par(), Red()})
+ /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+ /// ```
+ /// kw is always unrolled.
+ /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+ FailureOr<Operation *> dilated_conv() {
+ if (!valid)
+ return failure();
+
+ int nSize = lhsShapedType.getShape()[0];
+ int wSize = resShapedType.getShape()[1];
+ int cSize = lhsShapedType.getShape()[2];
+ int kwSize = rhsShapedType.getShape()[0];
+
+ vector::TransferWriteOp write;
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+
+ // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
+ // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+ int64_t wSizeStep = strideW == 1 ? wSize : 1;
+
+ Type lhsEltType = lhsShapedType.getElementType();
+ Type rhsEltType = rhsShapedType.getElementType();
+ Type resEltType = resShapedType.getElementType();
+ VectorType lhsType = VectorType::get(
+ {nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1,
+ cSize},
+ lhsEltType);
+ VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+ VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+
+ // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
+ Value lhs = builder.create<vector::TransferReadOp>(
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ // Read rhs slice of size {kw, c} @ [0, 0].
+ Value rhs = builder.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+ ValueRange{zero, zero});
+ // Read res slice of size {n, w, c} @ [0, 0, 0].
+ Value res = builder.create<vector::TransferReadOp>(
+ loc, resType, resShaped, ValueRange{zero, zero, zero});
+
+ //===------------------------------------------------------------------===//
+ // Begin vector-only rewrite part
+ //===------------------------------------------------------------------===//
+ // Unroll along kw and read slices of lhs and rhs.
+ SmallVector<Value> lhsVals, rhsVals, resVals;
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ // Extract rhs slice of size {c} @ [kw].
+ rhsVals.push_back(builder.create<vector::ExtractOp>(
+ loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+
+ for (int64_t w = 0; w < wSize; w += wSizeStep) {
+ // Extract lhs slice of size {n, wSizeStep, c}
+ // @ [0, sw * w + dw * kw, 0].
+ lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+ loc, lhs,
+ /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+
+ // This does not depend on kw.
+ if (kw == 0) {
+ // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
+ resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+ loc, res,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+ }
+ }
+ }
+
+ auto linearIndex = [&](int64_t kw, int64_t w) {
+ return kw * (wSize / wSizeStep) + w;
+ };
+
+ // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ for (int64_t w = 0; w < wSize; w += wSizeStep) {
+ resVals[w] = dilatedConv1dSliceAsContraction(
+ builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+ }
+ }
+
+ // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
+ // This does not depend on kw.
+ for (int64_t w = 0; w < wSize; w += wSizeStep) {
+ res = builder.create<vector::InsertStridedSliceOp>(
+ loc, resVals[w], res,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+ /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+ }
+ //===------------------------------------------------------------------===//
+ // End vector-only rewrite part
+ //===------------------------------------------------------------------===//
+
+ // Write back res slice of size {n, w, c} @ [0, 0, 0].
+ return builder
+ .create<vector::TransferWriteOp>(loc, res, resShaped,
+ ValueRange{zero, zero, zero})
+ .getOperation();
+ }
+
+ // Create a contraction: lhs{n, w, c} * rhs{c} -> res{n, w, c}
+ vector::ContractionOp dilatedConv1dSliceAsContraction(OpBuilder &b,
+ Location loc, Value lhs,
+ Value rhs, Value res) {
+ StringRef par = Par().strRef, red = Red().strRef;
+ AffineExpr n, w, c;
+ bindDims(ctx, n, w, c);
+ return builder.create<vector::ContractionOp>(
+ loc, lhs, rhs, res,
+ /*indexingMaps=*/MapList{{n, w, c}, {c}, {n, w, c}},
+ /*iteratorTypes=*/ArrayRef<StringRef>{par, par, red});
+ }
+
/// Entry point that transposes into the common form:
/// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
FailureOr<Operation *> generateConv() {
AffineExpr n, w, f, kw, c;
bindDims(ctx, n, w, f, kw, c);
-
if (!iters({Par(), Par(), Par(), Red(), Red()}))
return failure();
@@ -1570,6 +1698,22 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
return failure();
}
+ /// Entry point that transposes into the common form:
+ /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+ FailureOr<Operation *> generateDilatedConv() {
+ AffineExpr n, w, c, kw;
+ bindDims(ctx, n, w, c, kw);
+ if (!iters({Par(), Par(), Par(), Red()}))
+ return failure();
+
+ // No transposition needed.
+ if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+ /*rhsIndex*/ {kw, c},
+ /*resIndex*/ {n, w, c}}))
+ return dilated_conv();
+ return failure();
+ }
+
private:
bool valid;
int strideW, dilationW;
@@ -1588,8 +1732,11 @@ vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
- Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation);
- return e.generateConv();
+ Conv1D_NWC_Generator e(b, linalgOp, stride, dilation);
+ auto res = e.generateConv();
+ if (succeeded(res))
+ return res;
+ return e.generateDilatedConv();
}
struct VectorizeConvolution
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 0a1cbc41d58e..aa3d3f55953c 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -180,7 +180,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
-/// w == 1, kw == 1
+/// w == 0, kw == 1
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
@@ -189,3 +189,52 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// Write the result back in one shot.
// CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
+func @depthwise_conv1d_nwc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+ linalg.depthwise_conv1D_nw
+ {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+ outs(%output : memref<3x2x4xf32>)
+ return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+
+// CHECK: func @depthwise_conv1d_nwc_3x5x4_memref
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 0, kw == 0
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
+/// w == 0, kw == 1
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
+
+/// w == 0, kw == 0
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME: : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
+/// w == 0, kw == 1
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
+// CHECK-SAME: : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
More information about the Mlir-commits
mailing list