[Mlir-commits] [mlir] 7b615a8 - [mlir][linalg] Rewrite `linalg.conv_2d_nhwc_hwcf` into 1-D

Lei Zhang llvmlistbot at llvm.org
Tue Nov 2 07:00:54 PDT 2021


Author: Lei Zhang
Date: 2021-11-02T09:56:26-04:00
New Revision: 7b615a87dc559e24ef377313e00fe339b96c893e

URL: https://github.com/llvm/llvm-project/commit/7b615a87dc559e24ef377313e00fe339b96c893e
DIFF: https://github.com/llvm/llvm-project/commit/7b615a87dc559e24ef377313e00fe339b96c893e.diff

LOG: [mlir][linalg] Rewrite `linalg.conv_2d_nhwc_hwcf` into 1-D

We'd like to take a progressive approach towards convolution op
CodeGen, by 1) tiling it to fit the compute hierarchy first, and then
2) tiling along window dimensions with size 1 to reduce the problem
to a matmul-like one. After that, we can 3) downscale high-D
convolution ops to low-D ones by removing the size-1 window
dimensions. The final step would be 4) vectorizing the low-D
convolution op directly.

We have patterns for 1), 2), and 4). This commit adds a pattern for
3) for `linalg.conv_2d_nhwc_hwcf` ops as a starter. Supporting other
high-D convolution ops should be similar and mechanical.
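
For illustration only (not part of this commit): a minimal sketch of
how a pass might chain steps 3) and 4) using the two populate entry
points touched below. The wrapper function and its name are made up;
only the populate* calls and the greedy driver are existing APIs:

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Hypothetical driver: decompose high-D convolutions into low-D ones,
  // then vectorize the resulting low-D convolutions.
  static void lowerConvsProgressively(FuncOp funcOp) {
    MLIRContext *ctx = funcOp.getContext();

    // Step 3): remove size-1 window dimensions from high-D convolutions.
    RewritePatternSet decomposePatterns(ctx);
    linalg::populateDecomposeConvolutionPatterns(decomposePatterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(decomposePatterns));

    // Step 4): vectorize the low-D convolutions directly.
    RewritePatternSet vectorizePatterns(ctx);
    linalg::populateConvolutionVectorizationPatterns(vectorizePatterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorizePatterns));
  }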

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D112928

Added: 
    mlir/test/Dialect/Linalg/decompose-convolution.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bbaabac10cdb0..ddbc08cc3a188 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -46,7 +46,15 @@ void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
-/// Populates patterns for vectorizing convolution ops.
+/// Populates patterns to decompose high-D convolution ops into low-D ones. This
+/// is a step in progressive lowering for convolution ops; afterwards we can
+/// vectorize the low-D convolution ops.
+void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Populates patterns for vectorizing low-D convolution ops. This is a step in
+/// progressive lowering for convolution ops; it assumes high-D convolution ops
+/// were decomposed previously.
 void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);
 

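Both populate entry points also take an optional PatternBenefit. A
minimal, hypothetical sketch of using it (the helper name and the
benefit value 2 are arbitrary, chosen only to illustrate the
parameter):

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

  using namespace mlir;

  // Illustrative only: register the decomposition rewrites with a higher
  // benefit so the greedy driver tries them before benefit-1 patterns.
  static void addDecomposeConvolutionPatterns(RewritePatternSet &patterns) {
    linalg::populateDecomposeConvolutionPatterns(patterns, /*benefit=*/2);
  }
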
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index f81ce919a4faf..5e9073678015b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -840,3 +840,98 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
   return success();
 }
+
+namespace {
+// The following are patterns for downscaling convolution ops with size-1
+// window dimensions.
+//
+// Note that we'd eventually want to write such transformations in a generic
+// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
+// and then turning back to named ops. But for now it's fine to have a few
+// patterns matching special ops to get started.
+
+/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
+/// convolution ops.
+struct DownscaleSizeOneWindowed2DConvolution final
+    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
+    if (linalgOp.hasBufferSemantics())
+      return failure(); // To be implemented
+
+    Value input = convOp.inputs().front();
+    Value filter = convOp.inputs().back();
+    Value output = convOp.outputs().front();
+
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+    auto inputShape = inputType.getShape();
+    auto filterShape = filterType.getShape();
+    auto outputShape = outputType.getShape();
+
+    // Only handle the case where at least one of the window dimensions is
+    // of size 1. Other cases can rely on tiling to reduce to such cases.
+    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
+    int64_t ohSize = outputShape[1], owSize = outputShape[2];
+    if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
+      return failure();
+    bool removeH = ohSize == 1;
+
+    // Get new shapes and types for all operands by removing the size-1
+    // dimension.
+
+    SmallVector<int64_t, 3> newInputShape{
+        inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
+    auto newInputType = RankedTensorType::get(
+        newInputShape, inputType.getElementType(), inputType.getEncoding());
+
+    SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
+                                           filterShape[2], filterShape[3]};
+    auto newFilterType = RankedTensorType::get(
+        newFilterShape, filterType.getElementType(), filterType.getEncoding());
+
+    SmallVector<int64_t, 3> newOutputShape{
+        outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
+    auto newOutputType = RankedTensorType::get(
+        newOutputShape, outputType.getElementType(), outputType.getEncoding());
+
+    SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
+    SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};
+
+    // Reshape all operands for 1-D convolution.
+    Location loc = convOp.getLoc();
+    Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
+        loc, newInputType, input, ioReshapeIndices);
+    Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
+        loc, newFilterType, filter, fReshapeIndices);
+    Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
+        loc, newOutputType, output, ioReshapeIndices);
+
+    // We need to shrink the strides and dilations too.
+    auto stride = convOp.strides().getFlatValue<int64_t>(removeH ? 1 : 0);
+    auto stridesAttr = rewriter.getI64VectorAttr(stride);
+    auto dilation = convOp.dilations().getFlatValue<int64_t>(removeH ? 1 : 0);
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
+
+    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+        loc, newOutputType, ValueRange{newInput, newFilter},
+        ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+        convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
+    return success();
+  }
+};
+
+} // namespace
+
+void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
+                                                  PatternBenefit benefit) {
+  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
+                                                      benefit);
+}

diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
new file mode 100644
index 0000000000000..ebd7dd6d4a2af
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-decompose-convolution-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor
+//  CHECK-SAME: (%[[INPUT:.+]]: tensor<4x1x6x3xf32>, %[[FILTER:.+]]: tensor<1x2x3x8xf32>, %[[INIT:.+]]: tensor<4x1x2x8xf32>)
+func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x1x6x3xf32>, %filter: tensor<1x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<4x1x6x3xf32>, tensor<1x2x3x8xf32>)
+    outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
+  return %0 : tensor<4x1x2x8xf32>
+}
+
+//               CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<4x1x6x3xf32> into tensor<4x6x3xf32>
+//               CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}:   [[0, 1], [2], [3]] : tensor<1x2x3x8xf32> into tensor<2x3x8xf32>
+//               CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<4x1x2x8xf32> into tensor<4x2x8xf32>
+//               CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
+//          CHECK-SAME:     dilations = dense<3> : vector<1xi64>
+//          CHECK-SAME:     strides = dense<2> : vector<1xi64>
+//          CHECK-SAME:   ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
+//          CHECK-SAME:   outs(%[[INIT_1D]] : tensor<4x2x8xf32>)
+//               CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
+//               CHECK: return %[[CONV_2D]]
+
+// -----
+
+// CHECK-LABEL: func @conv2d_nhwc_qxqx1xq_tensor
+//  CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x1x?xf32>, %[[FILTER:.+]]: tensor<?x1x?x?xf32>, %[[INIT:.+]]: tensor<?x?x1x?xf32>)
+func @conv2d_nhwc_qxqx1xq_tensor(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x1x?x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x1x?x?xf32>)
+    outs(%init : tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  return %0 : tensor<?x?x1x?xf32>
+}
+
+//               CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+//               CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}:   [[0, 1], [2], [3]] : tensor<?x1x?x?xf32> into tensor<?x?x?xf32>
+//               CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+//               CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
+//          CHECK-SAME:     dilations = dense<2> : vector<1xi64>
+//          CHECK-SAME:     strides = dense<3> : vector<1xi64>
+//          CHECK-SAME:   ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+//          CHECK-SAME:   outs(%[[INIT_1D]] : tensor<?x?x?xf32>)
+//               CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
+//               CHECK: return %[[CONV_2D]]
+
+// -----
+
+// Do not convert convolution ops whose window dimensions are not ones.
+
+// CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor
+func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x3x5x3xf32>, %filter: tensor<2x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
+  // CHECK: linalg.conv_2d_nhwc_hwcf
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %filter : tensor<4x3x5x3xf32>, tensor<2x2x3x8xf32>)
+    outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
+  return %0 : tensor<4x1x2x8xf32>
+}

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 25525fb851d24..78711310a8d88 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -152,6 +152,11 @@ struct TestLinalgTransforms
       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                      "tiled_loop"),
       llvm::cl::init("for")};
+  Option<bool> testDecomposeConvolutionPattern{
+      *this, "test-decompose-convolution-patterns",
+      llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
+                     "into low-D ones"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -576,6 +581,12 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateDecomposeConvolutionPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
@@ -819,6 +830,8 @@ void TestLinalgTransforms::runOnFunction() {
     return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
   if (testInterchangePattern.hasValue())
     return applyInterchangePattern(getFunction(), testInterchangePattern);
+  if (testDecomposeConvolutionPattern)
+    return applyDecomposeConvolutionPatterns(getFunction());
 }
 
 namespace mlir {

