[Mlir-commits] [mlir] 392e16c - [mlir][Linalg] NFC - Cleanup conv1d generators
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jan 17 09:39:34 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-17T17:39:19Z
New Revision: 392e16c27ffce3c88d6d14050327bb7409756900
URL: https://github.com/llvm/llvm-project/commit/392e16c27ffce3c88d6d14050327bb7409756900
DIFF: https://github.com/llvm/llvm-project/commit/392e16c27ffce3c88d6d14050327bb7409756900.diff
LOG: [mlir][Linalg] NFC - Cleanup conv1d generators
Differential Revision: https://reviews.llvm.org/D117330
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3daf243ce472..db3e7197a0aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1287,6 +1287,22 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
//===----------------------------------------------------------------------===//
// Convolution vectorization patterns
//===----------------------------------------------------------------------===//
+
+template <int N>
+static void bindShapeDims(ShapedType shapedType) {} // End of recursion.
+
+template <int N, typename IntTy, typename... IntTy2>
+static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
+ val = shapedType.getShape()[N]; // Bind dimension N.
+ bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...); // Bind the rest.
+}
+
+/// Assign shapedType.getShape()[0], [1], ... to `vals` in order (no rank check).
+template <typename... IntTy>
+static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
+ bindShapeDims<0>(shapedType, vals...);
+}
+
namespace {
/// Generate a vector implementation for either:
/// ```
@@ -1354,11 +1370,11 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
if (!valid)
return failure();
- int nSize = lhsShapedType.getShape()[0];
- int wSize = resShapedType.getShape()[1];
- int cSize = lhsShapedType.getShape()[2];
- int kwSize = rhsShapedType.getShape()[0];
- int fSize = rhsShapedType.getShape()[2];
+ int64_t nSize, wSize, cSize, kwSize, fSize;
+ // kernel{kw, c, f}
+ bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
+ // out{n, w, f}
+ bindShapeDims(resShapedType, nSize, wSize);
vector::TransferWriteOp write;
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
@@ -1398,31 +1414,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
//===------------------------------------------------------------------===//
// Unroll along kw and read slices of lhs and rhs.
SmallVector<Value> lhsVals, rhsVals, resVals;
+ // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
for (int64_t kw = 0; kw < kwSize; ++kw) {
- // Extract rhs slice of size {c, f} @ [kw].
- rhsVals.push_back(builder.create<vector::ExtractOp>(
- loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- // Extract lhs slice of size {n, wSizeStep, c}
- // @ [0, sw * w + dw * kw, 0].
lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
loc, lhs,
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-
- // This does not depend on kw.
- if (kw == 0) {
- // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
- resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
- loc, res,
- /*offsets=*/ArrayRef<int64_t>{0, w, 0},
- /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
- /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
- }
}
}
+ // Extract rhs slice of size {c, f} @ [kw].
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ rhsVals.push_back(builder.create<vector::ExtractOp>(
+ loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+ }
+ // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
+ for (int64_t w = 0; w < wSize; w += wSizeStep) {
+ resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+ loc, res,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+ }
auto linearIndex = [&](int64_t kw, int64_t w) {
return kw * (wSize / wSizeStep) + w;
@@ -1476,14 +1490,15 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
- FailureOr<Operation *> dilatedConv() {
+ FailureOr<Operation *> depthwiseConv() {
if (!valid)
return failure();
- int nSize = lhsShapedType.getShape()[0];
- int wSize = resShapedType.getShape()[1];
- int cSize = lhsShapedType.getShape()[2];
- int kwSize = rhsShapedType.getShape()[0];
+ int64_t nSize, wSize, cSize, kwSize;
+ // kernel{kw, c}
+ bindShapeDims(rhsShapedType, kwSize, cSize);
+ // out{n, w, c}
+ bindShapeDims(resShapedType, nSize, wSize);
vector::TransferWriteOp write;
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
@@ -1522,31 +1537,30 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
//===------------------------------------------------------------------===//
// Unroll along kw and read slices of lhs and rhs.
SmallVector<Value> lhsVals, rhsVals, resVals;
+ // Extract lhs slice of size {n, wSizeStep, c}
+ // @ [0, sw * w + dw * kw, 0].
for (int64_t kw = 0; kw < kwSize; ++kw) {
- // Extract rhs slice of size {c} @ [kw].
- rhsVals.push_back(builder.create<vector::ExtractOp>(
- loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- // Extract lhs slice of size {n, wSizeStep, c}
- // @ [0, sw * w + dw * kw, 0].
lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
loc, lhs,
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-
- // This does not depend on kw.
- if (kw == 0) {
- // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
- resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
- loc, res,
- /*offsets=*/ArrayRef<int64_t>{0, w, 0},
- /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
- /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
- }
}
}
+ // Extract rhs slice of size {c} @ [kw].
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ rhsVals.push_back(builder.create<vector::ExtractOp>(
+ loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+ }
+ // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
+ for (int64_t w = 0; w < wSize; w += wSizeStep) {
+ resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+ loc, res,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+ }
auto linearIndex = [&](int64_t kw, int64_t w) {
return kw * (wSize / wSizeStep) + w;
@@ -1555,7 +1569,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- resVals[w] = dilatedConv1dSliceAsFma(
+ resVals[w] = depthwiseConv1dSliceAsFma(
builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
}
}
@@ -1580,8 +1594,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
}
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
- Value dilatedConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
- Value rhs, Value res) {
+ Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
+ Value rhs, Value res) {
Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
return b.create<vector::FMAOp>(loc, lhs, bcast, res);
}
@@ -1614,7 +1628,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}}))
- return dilatedConv();
+ return depthwiseConv();
return failure();
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index e545e7f191aa..3168e7f67d51 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -23,15 +23,15 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
-/// w == 0, kw == 0
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
-// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
-// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
-/// w == 1, kw == 0
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+
+// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
+
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
@@ -84,27 +84,23 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-
-/// w == 0, kw == 0
-// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
-// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
-// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
-/// w == 1, kw == 0
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
-// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
-// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
-
-/// w == 0, kw == 1
-// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
-/// w == 1, kw == 0
// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+
/// w == 0, kw == 0
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
@@ -165,15 +161,14 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-/// w == 0, kw == 0
-// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
-/// w == 0, kw == 1
-// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+
/// w == 0, kw == 0
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
@@ -211,15 +206,14 @@ func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: m
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 0, kw == 0
-// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
-/// w == 0, kw == 1
-// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
+
/// w == 0, kw = 0
// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
// CHECK: %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>
More information about the Mlir-commits
mailing list