[Mlir-commits] [mlir] 7b615a8 - [mlir][linalg] Rewrite `linalg.conv_2d_nhwc_hwcf` into 1-D
Lei Zhang
llvmlistbot at llvm.org
Tue Nov 2 07:00:54 PDT 2021
Author: Lei Zhang
Date: 2021-11-02T09:56:26-04:00
New Revision: 7b615a87dc559e24ef377313e00fe339b96c893e
URL: https://github.com/llvm/llvm-project/commit/7b615a87dc559e24ef377313e00fe339b96c893e
DIFF: https://github.com/llvm/llvm-project/commit/7b615a87dc559e24ef377313e00fe339b96c893e.diff
LOG: [mlir][linalg] Rewrite `linalg.conv_2d_nhwc_hwcf` into 1-D
We'd like to take a progressive approach towards convolution op
CodeGen: 1) tile it to fit the compute hierarchy first, then
2) tile along window dimensions with size 1 to reduce the problem
to a matmul-like one. After that, we can 3) downscale high-D convolution
ops to low-D ones by removing the size-1 window dimensions. The final
step is 4) vectorizing the low-D convolution op directly.
We have patterns for 1), 2), and 4). This commit adds a pattern for
3) for `linalg.conv_2d_nhwc_hwcf` ops as a starter. Supporting other
high-D convolution ops should be similar and mechanical.
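For example (condensed from the test added below; the SSA names are
illustrative), a `linalg.conv_2d_nhwc_hwcf` whose filter and output H
dimensions are both size 1 is rewritten to collapse the size-1 dimension
away, run a 1-D convolution, and expand the result back:

  %0 = linalg.conv_2d_nhwc_hwcf
         {dilations = dense<[2, 3]> : tensor<2xi64>,
          strides = dense<[3, 2]> : tensor<2xi64>}
         ins(%input, %filter : tensor<4x1x6x3xf32>, tensor<1x2x3x8xf32>)
         outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>

becomes

  %in = linalg.tensor_collapse_shape %input [[0], [1, 2], [3]]
          : tensor<4x1x6x3xf32> into tensor<4x6x3xf32>
  %fil = linalg.tensor_collapse_shape %filter [[0, 1], [2], [3]]
          : tensor<1x2x3x8xf32> into tensor<2x3x8xf32>
  %out = linalg.tensor_collapse_shape %init [[0], [1, 2], [3]]
          : tensor<4x1x2x8xf32> into tensor<4x2x8xf32>
  %conv = linalg.conv_1d_nwc_wcf
            {dilations = dense<3> : vector<1xi64>,
             strides = dense<2> : vector<1xi64>}
            ins(%in, %fil : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
            outs(%out : tensor<4x2x8xf32>) -> tensor<4x2x8xf32>
  %0 = linalg.tensor_expand_shape %conv [[0], [1, 2], [3]]
         : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>

Note how the strides and dilations shrink along with the removed dimension.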
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D112928
Added:
mlir/test/Dialect/Linalg/decompose-convolution.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bbaabac10cdb0..ddbc08cc3a188 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -46,7 +46,15 @@ void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
-/// Populates patterns for vectorizing convolution ops.
+/// Populates patterns to decompose high-D convolution ops into low-D ones.
+/// This is a step in the progressive lowering of convolution ops; afterwards
+/// we can vectorize the low-D convolution ops.
+void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populates patterns for vectorizing low-D convolution ops. This is a step
+/// in the progressive lowering of convolution ops; it assumes high-D
+/// convolution ops were decomposed previously.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index f81ce919a4faf..5e9073678015b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -840,3 +840,98 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
return success();
}
+
+namespace {
+// The following are patterns for downscaling convolution ops with size-1
+// window dimensions.
+//
+// Note that we'd eventually want to write such transformations in a generic
+// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
+// and then turning back to named ops. But for now it's fine to have a few
+// patterns matching special ops to get started.
+
+/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
+/// convolution ops.
+struct DownscaleSizeOneWindowed2DConvolution final
+ : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+ PatternRewriter &rewriter) const override {
+ auto linalgOp = cast<linalg::LinalgOp>(*convOp);
+ if (linalgOp.hasBufferSemantics())
+ return failure(); // To be implemented
+
+ Value input = convOp.inputs().front();
+ Value filter = convOp.inputs().back();
+ Value output = convOp.outputs().front();
+
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto filterType = filter.getType().dyn_cast<RankedTensorType>();
+ auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+ auto inputShape = inputType.getShape();
+ auto filterShape = filterType.getShape();
+ auto outputShape = outputType.getShape();
+
+ // Only handle the case where at least one of the window dimensions is
+ // of size 1. Other cases can rely on tiling to reduce to such cases.
+ int64_t fhSize = filterShape[0], fwSize = filterShape[1];
+ int64_t ohSize = outputShape[1], owSize = outputShape[2];
+ if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
+ return failure();
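+ // If both window dimensions have size 1, prefer removing the H dimension.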
+ bool removeH = ohSize == 1;
+
+ // Get new shapes and types for all operands by removing the size-1
+ // dimension.
+
+ SmallVector<int64_t, 3> newInputShape{
+ inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
+ auto newInputType = RankedTensorType::get(
+ newInputShape, inputType.getElementType(), inputType.getEncoding());
+
+ SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
+ filterShape[2], filterShape[3]};
+ auto newFilterType = RankedTensorType::get(
+ newFilterShape, filterType.getElementType(), filterType.getEncoding());
+
+ SmallVector<int64_t, 3> newOutputShape{
+ outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
+ auto newOutputType = RankedTensorType::get(
+ newOutputShape, outputType.getElementType(), outputType.getEncoding());
+
+ SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
+ SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};
+
+ // Reshape all operands for 1-D convolution.
+ Location loc = convOp.getLoc();
+ Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
+ loc, newInputType, input, ioReshapeIndices);
+ Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
+ loc, newFilterType, filter, fReshapeIndices);
+ Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
+ loc, newOutputType, output, ioReshapeIndices);
+
+ // We need to shrink the strides and dilations too.
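+ // E.g., with the H dimension removed, strides dense<[3, 2]> shrink to
+ // dense<2> and dilations dense<[2, 3]> shrink to dense<3>.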
+ auto stride = convOp.strides().getFlatValue<int64_t>(removeH ? 1 : 0);
+ auto stridesAttr = rewriter.getI64VectorAttr(stride);
+ auto dilation = convOp.dilations().getFlatValue<int64_t>(removeH ? 1 : 0);
+ auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
+
+ auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+ loc, newOutputType, ValueRange{newInput, newFilter},
+ ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+ rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+ convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
+ return success();
+ }
+};
+
+} // namespace
+
+void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
+ benefit);
+}
diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
new file mode 100644
index 0000000000000..ebd7dd6d4a2af
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-decompose-convolution-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<4x1x6x3xf32>, %[[FILTER:.+]]: tensor<1x2x3x8xf32>, %[[INIT:.+]]: tensor<4x1x2x8xf32>)
+func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x1x6x3xf32>, %filter: tensor<1x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
+ ins(%input, %filter : tensor<4x1x6x3xf32>, tensor<1x2x3x8xf32>)
+ outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
+ return %0 : tensor<4x1x2x8xf32>
+}
+
+// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x6x3xf32> into tensor<4x6x3xf32>
+// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<1x2x3x8xf32> into tensor<2x3x8xf32>
+// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x2x8xf32> into tensor<4x2x8xf32>
+// CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
+// CHECK-SAME: dilations = dense<3> : vector<1xi64>
+// CHECK-SAME: strides = dense<2> : vector<1xi64>
+// CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
+// CHECK-SAME: outs(%[[INIT_1D]] : tensor<4x2x8xf32>)
+// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
+// CHECK: return %[[CONV_2D]]
+
+// -----
+
+// CHECK-LABEL: func @conv2d_nhwc_qxqx1xq_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x1x?xf32>, %[[FILTER:.+]]: tensor<?x1x?x?xf32>, %[[INIT:.+]]: tensor<?x?x1x?xf32>)
+func @conv2d_nhwc_qxqx1xq_tensor(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x1x?x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
+ ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x1x?x?xf32>)
+ outs(%init : tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+ return %0 : tensor<?x?x1x?xf32>
+}
+
+// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<?x1x?x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
+// CHECK-SAME: dilations = dense<2> : vector<1xi64>
+// CHECK-SAME: strides = dense<3> : vector<1xi64>
+// CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[INIT_1D]] : tensor<?x?x?xf32>)
+// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
+// CHECK: return %[[CONV_2D]]
+
+// -----
+
+// Do not convert convolution ops whose window dimensions are not of size 1.
+
+// CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor
+func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x3x5x3xf32>, %filter: tensor<2x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
+ // CHECK: linalg.conv_2d_nhwc_hwcf
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %filter : tensor<4x3x5x3xf32>, tensor<2x2x3x8xf32>)
+ outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
+ return %0 : tensor<4x1x2x8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 25525fb851d24..78711310a8d88 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -152,6 +152,11 @@ struct TestLinalgTransforms
llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
"tiled_loop"),
llvm::cl::init("for")};
+ Option<bool> testDecomposeConvolutionPattern{
+ *this, "test-decompose-convolution-patterns",
+ llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
+ "into low-D ones"),
+ llvm::cl::init(false)};
};
} // end anonymous namespace
@@ -576,6 +581,12 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
+ RewritePatternSet patterns(funcOp.getContext());
+ populateDecomposeConvolutionPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
@@ -819,6 +830,8 @@ void TestLinalgTransforms::runOnFunction() {
return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
if (testInterchangePattern.hasValue())
return applyInterchangePattern(getFunction(), testInterchangePattern);
+ if (testDecomposeConvolutionPattern)
+ return applyDecomposeConvolutionPatterns(getFunction());
}
namespace mlir {