[Mlir-commits] [mlir] 991945f - [mlir][linalg] Downscale 2D convolution with unit dimensions to 1D convolution

Hanhan Wang llvmlistbot at llvm.org
Wed Mar 8 14:32:04 PST 2023


Author: Devajith Valaparambil Sreeramaswamy
Date: 2023-03-08T14:31:54-08:00
New Revision: 991945f4410af9df33f0889bf3c0695fd45a28b1

URL: https://github.com/llvm/llvm-project/commit/991945f4410af9df33f0889bf3c0695fd45a28b1
DIFF: https://github.com/llvm/llvm-project/commit/991945f4410af9df33f0889bf3c0695fd45a28b1.diff

LOG: [mlir][linalg] Downscale 2D convolution with unit dimensions to 1D convolution

Decompose conv_2d -> conv_1d.

This MR follows a similar approach to https://reviews.llvm.org/D112928.

This patch adds support for converting a conv_2d operation with either unit height or unit width into a conv_1d operation.

This is useful when a 2D convolution has been tiled so that either its height or its width has size one: once it is decomposed into a 1D convolution, it can be vectorized.

The patch https://reviews.llvm.org/D145160 adds vectorization support for the linalg.conv_1d operation, thereby allowing us to vectorize a linalg.conv_2d operation after proper tiling.

The missing feature was reported here: https://discourse.llvm.org/t/vectorization-of-convolution-op/60458.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D145162

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-op-decompose.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4dd7641e8193b..eaf9fec6e5cf5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1041,6 +1041,19 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
   }
 };
 
+struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
+  DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<Conv2DOp>(context, benefit) {}
+
+  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+                                               PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(Conv2DOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(convOp, rewriter);
+  }
+};
+
 ///
 /// Linalg generalization pattern.
 ///

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 600cddec47d17..3e6e1dfdc381d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -266,6 +266,7 @@ transform::DecomposeOp::applyToOne(LinalgOp target,
   DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
   DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
   DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
+  DOWNSCALE(DownscaleConv2DOp)
 #undef DOWNSCALE_NORMAL
 #undef DOWNSCALE_CALL
 #undef DOWNSCALE

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 01f2c17c39cda..9de0f763d3292 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1361,14 +1361,71 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
   return conv1DOp;
 }
 
+FailureOr<Conv1DOp>
+DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+                                            PatternRewriter &rewriter) const {
+  if (convOp.hasBufferSemantics())
+    return failure(); // Buffer semantics are not implemented yet.
+
+  Value input = convOp.getInputs().front();
+  Value kernel = convOp.getInputs().back();
+  Value output = convOp.getOutputs().front();
+
+  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+  auto kernelShape = kernelType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Only handle the case where the kernel and the output both have size 1
+  // along the same dimension. Other cases can rely on tiling to get there.
+  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+  int64_t ohSize = outputShape[0], owSize = outputShape[1];
+  bool removeH = (khSize == 1 && ohSize == 1);
+  bool removeW = (kwSize == 1 && owSize == 1);
+  if (!removeH && !removeW)
+    return failure();
+
+  // Build the new operand types by dropping the size-1 dimension; dimension 0
+  // is dropped when it qualifies, otherwise dimension 1.
+  using RTTBuilder = RankedTensorType::Builder;
+  RankedTensorType newInputType =
+      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newKernelType =
+      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newOutputType =
+      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
+
+  // Rank-reduce each operand to its new type with an extract_slice.
+  Location loc = convOp.getLoc();
+  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, input, newInputType);
+  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, kernel, newKernelType);
+  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, output, newOutputType);
+
+  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
+                                            ValueRange{newInput, newKernel},
+                                            ValueRange{newOutput});
+
+  // Insert the 1-D result back into the original 2-D output.
+  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, conv1DOp.getResult(0), output);
+  rewriter.replaceOp(convOp, inserted);
+
+  return conv1DOp;
+}
+
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
   patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                      Conv1DNwcWcfOp>,
                DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                      Conv1DNcwFcwOp>,
-               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
-                                                  benefit);
+               DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
+      patterns.getContext(), benefit);
   patterns.add<
       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
       DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,

diff  --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index e023e64c7cc8b..82795ec9d4bf8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -56,6 +56,23 @@ func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: t
   return %0: tensor<1x1x56x96xf32>
 }
 
+// CHECK-LABEL: @conv_2d
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
+func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.conv_2d
+     ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<1x?xf32>
+}
+
 // CHECK-LABEL: @pooling_nhwc_sum
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>


        


More information about the Mlir-commits mailing list