[Mlir-commits] [mlir] 0b17d47 - [mlir][Linalg] Tile sizes for Conv ops vectorization added as pass arguments
Jakub Lichman
llvmlistbot at llvm.org
Wed Sep 30 04:31:54 PDT 2020
Author: Jakub Lichman
Date: 2020-09-30T11:31:28Z
New Revision: 0b17d4754a94b7129c2483762acd586783802b12
URL: https://github.com/llvm/llvm-project/commit/0b17d4754a94b7129c2483762acd586783802b12
DIFF: https://github.com/llvm/llvm-project/commit/0b17d4754a94b7129c2483762acd586783802b12.diff
LOG: [mlir][Linalg] Tile sizes for Conv ops vectorization added as pass arguments
The current setup for conv op vectorization does not let the user specify tile
sizes or choose which dimensions are vectorized. This commit changes that by
adding tile sizes as pass arguments. Every dimension whose corresponding tile
size is greater than 1 is automatically vectorized.
Differential Revision: https://reviews.llvm.org/D88533
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
mlir/test/lib/Transforms/TestConvVectorization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b188fde5d801..00a094d72076 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -32,7 +32,8 @@ struct TiledLinalgOp {
/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
- MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns);
+ MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
+ ArrayRef<int64_t> tileSizes);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
@@ -549,8 +550,8 @@ struct AffineMinSCFCanonicalizationPattern
/// false of size 1. This ensures that the ConvOp can be lowered to vector
/// contraction of dimensions marked in the *mask* as true.
///
-/// A good example is ConvNHWCOp which is 2D Conv op with channels as the last
-/// dimension. For this op we contract last 3 dimensions.
+/// A good example for vectorization is ConvNHWCOp, which is a 2D Conv op
+/// with channels as the last dimension. Let's vectorize the last 3 dimensions.
/// The initial op definition looks like this:
/// ```
/// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 :
@@ -589,10 +590,6 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
LogicalResult matchAndRewrite(ConvOp minOp,
PatternRewriter &rewriter) const override;
-
- // TODO: Make these pass arguments.
- static const int tileSize = 3;
- static const int noTile = 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
index 97ea95c8bcd1..7cc0875b3353 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=4" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
index dcfcc9b62bbc..7f90ac675f72 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,4" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
index 2e79b46801bc..3eb0959ddda1 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,4" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
index e271b0a009b6..787cbf5d268b 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,2" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,1,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
index e27c40524fcc..c6236db6a05a 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,0,4,4" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
index b5b4a5c82c09..3213b7dc5fe2 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,3,2" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
index 12ea94696660..8020f3ac017f 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,2,2" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,1,1,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
index e36abc83b700..830b5402c2a4 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,5,5,5" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
index b302b3e0d8bd..0b25ea09157c 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
@@ -9,13 +9,13 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,5,5,5" \
-// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \
+// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9a225dd81c79..4430c34af1e9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -385,16 +385,19 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
return failure();
SmallVector<AffineExpr, 4> mapping;
- // Fail to apply when the size of not vectorized dimension is not 1 or
- // when the size of vectorized dimension is not dimSize.
+ SmallVector<int64_t, 4> vectorDims;
+ // Fail to apply when a non-vectorized dimension has a size other than 1.
for (unsigned i = 0; i < N; i++) {
if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
return failure();
- if (mask[i] && (inShape[i] != tileSize || kShape[i] != tileSize))
+
+ if (mask[i] && inShape[i] != kShape[i])
return failure();
- if (mask[i])
+ if (mask[i]) {
mapping.push_back(getAffineDimExpr(i, context));
+ vectorDims.push_back(inShape[i]);
+ }
}
Value input = op.getInput(0);
@@ -407,8 +410,7 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
auto map = AffineMap::get(rank, 0, mapping, context);
SmallVector<Value, 4> zeros(rank, std_constant_index(0));
- auto vecType =
- VectorType::get(SmallVector<int64_t, 4>(numDims, tileSize), elemType);
+ auto vecType = VectorType::get(vectorDims, elemType);
auto inputVec = vector_transfer_read(vecType, input, zeros, map);
auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
@@ -443,6 +445,9 @@ populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
OwningRewritePatternList &vectorizationPatterns,
ArrayRef<int64_t> tileSizes,
MLIRContext *context) {
+ if (tileSizes.size() < N)
+ return;
+
constexpr static StringRef kTiledMarker = "TILED";
constexpr static StringRef kPromotedMarker = "PROMOTED";
tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
@@ -457,49 +462,41 @@ populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
SmallVector<bool, 4> mask(N);
int offset = tileSizes.size() - N;
std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
- [](int64_t i) -> bool { return i != ConvOpConst::noTile; });
+ [](int64_t i) -> bool { return i > 1; });
vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
}
void mlir::linalg::populateConvVectorizationPatterns(
- MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns) {
- const int64_t tileSize = ConvOpConst::tileSize;
- const int64_t noTile = ConvOpConst::noTile;
- auto makeTileSizes = [&](unsigned numNoTile, unsigned numTile) {
- SmallVector<int64_t, 10> result(numNoTile, noTile);
- result.append(numTile, tileSize);
- return result;
- };
-
+ MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
+ ArrayRef<int64_t> tileSizes) {
OwningRewritePatternList tiling, promotion, vectorization;
- populateVectorizationPatterns<ConvWOp, 1>(
- tiling, promotion, vectorization,
- makeTileSizes(/*numNoTile=*/1, /*numTile*/ 1), context);
+ populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
+ tileSizes, context);
populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
- makeTileSizes(3, 2), context);
+ tileSizes, context);
populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
- makeTileSizes(3, 2), context);
+ tileSizes, context);
populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
- makeTileSizes(2, 2), context);
+ tileSizes, context);
populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
- makeTileSizes(4, 3), context);
+ tileSizes, context);
populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
- makeTileSizes(4, 3), context);
+ tileSizes, context);
populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
- makeTileSizes(3, 3), context);
+ tileSizes, context);
populateVectorizationPatterns<ConvNDHWCOp, 5>(
- tiling, promotion, vectorization, makeTileSizes(5, 4), context);
+ tiling, promotion, vectorization, tileSizes, context);
populateVectorizationPatterns<ConvNCDHWOp, 5>(
- tiling, promotion, vectorization, makeTileSizes(5, 4), context);
+ tiling, promotion, vectorization, tileSizes, context);
patterns.push_back(std::move(tiling));
patterns.push_back(std::move(promotion));
diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
index eeb2ca31fd2a..e1bb7f3caabb 100644
--- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-conv-vectorization --cse | FileCheck %s
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse | FileCheck %s
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
index c90d8058de32..79b6464f3b4c 100644
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -24,6 +24,13 @@ namespace {
/// A pass converting MLIR Linalg ops into Vector ops.
class TestConvVectorization
: public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
+public:
+ TestConvVectorization() = default;
+ TestConvVectorization(const TestConvVectorization &) {}
+ explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
+ tileSizes = tileSizesParam;
+ }
+
void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
@@ -33,6 +40,10 @@ class TestConvVectorization
registry.insert<AffineDialect>();
registry.insert<StandardOpsDialect>();
}
+
+ ListOption<int64_t> tileSizes{
+ *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
} // namespace
@@ -47,7 +58,7 @@ void TestConvVectorization::runOnOperation() {
target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
- linalg::populateConvVectorizationPatterns(context, stage1Patterns);
+ linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
OwningRewritePatternList stage2Patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
More information about the Mlir-commits
mailing list